diff --git a/.gitattributes b/.gitattributes index 67b3a75fef9ccbc3c54a6b447a6a0894169c1b70..2e68d1b7d07edb82cfcd070d9e5a2dc3178ef9a7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -54,3 +54,10 @@ phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge= phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..504cfda7fc8344e91695b376265c42140219835b --- /dev/null +++ b/phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9ae9e3e39533b703fa0fb49576d02a073be55a6fbe3f9d9a38cbeb9ed03e116 +size 100308 diff --git a/phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc b/phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9081812ac0f41891e67fa7ce5c686b4050fa4871 --- /dev/null +++ b/phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88b150085f0eb6dcd1c70d632b16ccf923b66e1700800a4756d06b3726b91fcf +size 132673 diff --git a/phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3bc3f4df2d424574a75f4578ab31f675b68119f --- /dev/null +++ b/phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92b5449a62f76826fcde2e62b85c16f953a7ccfff8847bc4854a098b5a954dae +size 100411 diff --git a/phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc b/phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e0eaac81e7cdcc596fc50d0059836b7a59d645 --- /dev/null +++ b/phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:3f26434c485b7881d3ef563e57c88a171319a39cfcc3bf348cbe5bfd0d2a9887 +size 201319 diff --git a/phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc b/phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b64dd06ccd2205f42c6837abded8a2c0028cbdcf --- /dev/null +++ b/phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5727abd2cd4972398036f183a2e811e78ffa31946bf89c453917a171a61c12aa +size 114484 diff --git a/phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc b/phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1874a5b4e528fda84c4106056d3b67856d58525d --- /dev/null +++ b/phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa833453940a5409176fe65a5ba338e66d9d875a4a905f92c064b1ade0faba66 +size 140105 diff --git a/phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd b/phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..fb254b7091fec5b2a3e07d8b3d9e036fb1345668 --- /dev/null +++ b/phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72ee579e80fb57b5b52f1a5a44b4dcbf85567e43442ad80f9da51f21e2f9977f +size 723968 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Formatting.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Formatting.h new file mode 100644 index 0000000000000000000000000000000000000000..0f0476a497db0e0329d2308ef2d219a65e57656f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Formatting.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { +TORCH_API std::ostream& operator<<(std::ostream& out, Backend b); +TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s); +TORCH_API std::string toString(const Scalar& s); +} +namespace at { + +TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t); +TORCH_API std::ostream& print( + std::ostream& stream, + const Tensor& tensor, + int64_t linesize); +inline std::ostream& operator<<(std::ostream & out, const Tensor & t) { + return print(out,t,80); +} +TORCH_API void print(const Tensor & t, int64_t linesize=80); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Generator.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Generator.h new file mode 100644 index 0000000000000000000000000000000000000000..5a54146c2e37fc70377456d3701db7578a2b7e66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Generator.h @@ -0,0 +1,191 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +// For the record I don't think this is a correct pimpl idiom. +// Including Impl header in interface header defeats the purpose +// because you can't change Impl private members without forcing +// everything that included the interface to rebuild. +// Impl should be forward-declared in the interface header instead. 
+#include + +/** + * Note [Generator] + * ~~~~~~~~~~~~~~~~ + * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to + * generate a seemingly random sequence of numbers, that may be later be used in creating + * a random distribution. Such an engine almost always maintains a state and requires a + * seed to start off the creation of random numbers. Often times, users have + * found it beneficial to be able to explicitly create, retain, and destroy + * PRNG states and also be able to have control over the seed value. + * + * A Generator in ATen gives users the ability to read, write and modify a PRNG engine. + * For instance, it does so by letting users seed a PRNG engine, fork the state of the + * engine, etc. + * + * By default, there is one generator per device, and a device's generator is + * lazily created. A user can use the torch.Generator() api to create their own generator. + */ + +/** + * Note [Acquire lock when using random generators] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Generator and its derived classes are NOT thread-safe. Please note that most of the + * places where we have inserted locking for generators are historically based, and we + * haven't actually checked that everything is truly thread safe (and it probably isn't). + * Please use the public mutex_ when using any methods from these classes, except for the + * read-only methods. You can learn about the usage by looking into the unittests + * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard. + * + * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making + * them non-thread safe and instead making the generator state splittable, to accommodate + * forks into other threads). + */ + +namespace at { + +class Tensor; + +struct TORCH_API Generator { + Generator() = default; + + explicit Generator(c10::intrusive_ptr gen_impl) + : impl_(std::move(gen_impl)) { + if (impl_.get() == nullptr) { + throw std::runtime_error("GeneratorImpl with nullptr is not supported"); + } + } + + bool operator==(const Generator& rhs) const { + return this->impl_ == rhs.impl_; + } + + bool operator!=(const Generator& rhs) const { + return !((*this) == rhs); + } + + bool defined() const { + return static_cast(impl_); + } + + c10::GeneratorImpl* unsafeGetGeneratorImpl() const { + return impl_.get(); + } + + c10::GeneratorImpl* unsafeReleaseGeneratorImpl() { + return impl_.release(); + } + + const c10::intrusive_ptr& getIntrusivePtr() const { + return impl_; + } + + void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); } + // Sets the offset of Generator state to the desired offset. This is currently + // supported for only Philox based Generators, i.e., CUDA and MPS. + void set_offset(uint64_t offset) { impl_->set_offset(offset); } + + // Returns the offset of Generator state. This is currently supported for only + // Philox based Generators, i.e., CUDA and MPS. 
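  // Example (see Note [Acquire lock when using random generators] above):
  // a rough sketch of guarding mutating calls with the generator's public mutex.
  // The CPU generator factory used below is an assumption made for illustration.
  //
  //   at::Generator gen = at::detail::createCPUGenerator();
  //   {
  //     std::lock_guard<std::mutex> lock(gen.mutex());
  //     gen.set_current_seed(42);          // mutating call: hold the lock
  //   }
  //   uint64_t seed = gen.current_seed();  // read-only access needs no lock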
+ uint64_t get_offset() const { return impl_->get_offset(); } + + uint64_t current_seed() const { return impl_->current_seed(); } + + uint64_t seed() { return impl_->seed(); } + + // Implementation not inlined to prevent cycle reference between + // `ATen/core/Generator.h` and `ATen/core/Tensor.h` + void set_state(const at::Tensor& new_state); + + at::Tensor get_state() const; + + void graphsafe_set_state(const Generator& new_state); + + Generator graphsafe_get_state() const; + + std::mutex& mutex() { + return impl_->mutex_; + } + + DispatchKeySet key_set() const { + return impl_->key_set(); + } + + Device device() const { return impl_->device(); } + + inline void set_pyobj(PyObject* pyobj) const noexcept { + impl_->set_pyobj(pyobj); + } + + inline PyObject* pyobj() const noexcept { + return impl_->pyobj(); + } + + template + T* get() const { return static_cast(impl_.get()); } + + Generator clone() const { + return Generator(impl_->clone()); + } + + private: + c10::intrusive_ptr impl_; +}; + +template +Generator make_generator(Args&&... args) { + return Generator(c10::make_intrusive(std::forward(args)...)); +} + +/** + * Utility function to static cast input Generator* to + * the backend generator type (CPU/CUDAGeneratorImpl etc.) + */ +template +inline T * check_generator(std::optional gen) { + TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt"); + TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed"); + TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'"); + return gen->get(); +} + +/** + * Utility function used in tensor implementations, which + * supplies the default generator to tensors, if an input generator + * is not supplied. The input Generator* is also static casted to + * the backend generator type (CPU/CUDAGeneratorImpl etc.) + */ +template +inline T* get_generator_or_default(const std::optional& gen, const Generator& default_gen) { + return gen.has_value() && gen->defined() ? check_generator(gen) : check_generator(default_gen); +} + +namespace detail { + +/** + * Helper function for checking the validity of new random generator + * state. 
Right now following conditions are checked: + * + * - The new state tensor must be a torch.ByteTensor + * - Data of the new state tensor must be contiguous + */ +inline void check_rng_state(const c10::TensorImpl& new_state) { + TORCH_CHECK_TYPE( + new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte, + "RNG state must be a torch.ByteTensor" + ); + + TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous"); +} + +} // namespace detail + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h b/phivenv/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h new file mode 100644 index 0000000000000000000000000000000000000000..f7e8e25ce8610b39dc189146bcd46d73ddfe9a2d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include + +namespace at { + +using GeneratorFuncType = std::function; + +TORCH_API std::optional& GetGeneratorPrivate(); + +class TORCH_API _GeneratorRegister { + public: + explicit _GeneratorRegister(const GeneratorFuncType& func); +}; + +TORCH_API at::Generator GetGeneratorForPrivateuse1( + c10::DeviceIndex device_index); + +/** + * This is used to register Generator to PyTorch for `privateuse1` key. + * + * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1) + * + * class CustomGeneratorImpl : public c10::GeneratorImpl { + * CustomGeneratorImpl(DeviceIndex device_index = -1); + * explicit ~CustomGeneratorImpl() override = default; + * ... + * }; + * + * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) { + * return at::make_generator(id); + * } + */ + +#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \ + static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/IListRef.h b/phivenv/Lib/site-packages/torch/include/ATen/core/IListRef.h new file mode 100644 index 0000000000000000000000000000000000000000..d3d15178a2bb8d70839fcc69bdcdf01ff6e68c88 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/IListRef.h @@ -0,0 +1,631 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +/* + * [Note: IListRef] + * Wrapper around different API containers (e.g. boxed and unboxed). + * + * What is it? + * =========== + * It is a tagged union of both boxed and unboxed API containers. + * Working implementations: + * + * - `IListRef` + * - `IListRef` + * + * Note that `IListRef` is a view type. Meaning that it won't own the + * tensors it holds. It's intended to be used only as argument parameters. + * Specifically, where these 2 worlds overlap. + * + * What is this for? + * ================= + * Historically, PyTorch has maintained 2 different APIs: the unboxed + * (called from C++ API and Python eager mode) and boxed APIs (called + * from the TorchScript JIT, mobile interpreter, and boxed fallbacks). + * + * Calling unboxed kernels from the boxed "world" and vice-versa may + * result in non-negligible overhead. Lists are one of those types: + * + * - Boxed world: `c10::List` + * - Unboxed world: `c10::ArrayRef` + * + * In this context, `c10::IListRef` solves this problem by wrapping those + * 2 container types, so that we don't need to convert from one to + * the other. + * + * (see https://github.com/pytorch/pytorch/issues/66328) + * + * What does it do? 
+ * ================ + * This container wraps around the different tagged containers + * (currently, only boxed and unboxed), without incurring in extra + * overhead for converting from one to another. It does so while + * exposing usual container methods, which dispatch to corresponding + * implementations. + * + * While it works with different container types, it introduces + * overhead for repeatedly calling member functions (since those will + * get dispatched, again). Therefore, you should only use it to iterate + * through the list up to one time. If you need to do more complex things, + * call `materialize()` first. + * + * Adding support for a new Tag + * ============================ + * Suppose we want to add a new tag: `Chest`. Here are the steps + * we would have to go through: + * + * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`. + * + * #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \ + * ... + * _(Chest, ##__VA_ARGS__) + * + * 2. Add type aliases, union members, and constructors. + * + * template + * class IListRef { + * ... + * using chest_type = + * typename detail::IListRefTagImpl::list_type; + * ... + * IListRef(...) : tag_(IListRefTag::Chest) { + * ... + * } + * ... + * union Payload { + * ... + * chest_type chest; + * ... + * }; + * ... + * }; + * + * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's + * preferable to make the default implementation work for `T = Tensor` + * (both `Unboxed` and `Boxed` do it). + * + * template + * class IListRefTagImplBase { + * public: + * using elem_type = ListElemT; + * using list_type = ChestContainer; + * + * static const list_type& unwrap(const IListRef& ilist) { ... } + * + * static typename list_type::const_iterator& unwrap( + * IListRefIterator& it) { ... } + * + * static const typename list_type::const_iterator& unwrap( + * const IListRefIterator& it) { ... } + * + * static IListRefConstRef iterator_get( + * const typename list_type::const_iterator& it) { ... } + * } + * + * 4. Add an specialization for each of the already supported types. + * Finally, for consistency, add them to the tracking list. + * (see [Note: IListRefTagImpl Specializations]) + * + * template <> + * class IListRefTagImpl + * : public IListRefTagImplBase {}; + * + * Adding support for a new Type + * ============================= + * Suppose we want to add support for a new type: `Matrix`. + * Here are the steps we would have to go through: + * + * 1. Add an specialization for each of the existing tags. + * For consistency, add them to the tracking list. + * (see [Note: IListRefTagImpl Specializations]) + * + * template <> + * class IListRefTagImpl + * : public IListRefTagImplBase {}; + * + * template <> + * class IListRefTagImpl + * : public IListRefTagImplBase {}; + * + * Common Problems + * =============== + * 1. One of `IListRef(Iterator)` methods are failing to compile. + * + * That may be happening because the container type you added + * is not compatible with the code written for that method. If + * that's true, then you might have to transform that code into + * a static method call (see `List::operator[]` method). + * + * 2. Can't make `IListRefIterator::operator*` return a const-reference. + * + * First, keep in mind that we assume that boxed containers will + * have to deal with `IValue` (e.g. `c10::List`). In this context, + * what may be happening is that `IValue` doesn't store internally + * your type `T`. Instead, it constructs a type new `T` everytime + * you try to get `T` for it (see `IListRef`). 
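 *
 * Usage sketch
 * ============
 * A rough example of consuming an `IListRef<at::Tensor>` (aliased later, in
 * `IListRef_inl.h`, as `at::ITensorListRef`). The helper function name is
 * made up for illustration:
 *
 *   int64_t count_defined(at::ITensorListRef tensors) {
 *     // single pass over the list: iterating the IListRef directly is fine
 *     int64_t n = 0;
 *     for (const at::Tensor& t : tensors) {
 *       n += t.defined() ? 1 : 0;
 *     }
 *     return n;
 *   }
 *
 *   // If the list will be traversed more than once, materialize it first:
 *   //   auto vec = tensors.materialize();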
+ */ + +namespace c10 { +template +class IListRef; + +/* + * Applies arbitrary macros to each `IListRefTag`. + */ +#define TORCH_ILISTREF_FORALL_TAGS(_, ...) \ + _(Unboxed, ##__VA_ARGS__) \ + _(Boxed, ##__VA_ARGS__) \ + _(Materialized, ##__VA_ARGS__) + +/* + * Defines a "switch-case" for `TAG`. Inside, it executes `BODY`, + * while bringing to scope: + * + * - `ImplT`: the implementation class for `TAG` + * - `this_`: the result of unwrapping `this` + */ +#define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY) \ + case c10::IListRefTag::TAG: { \ + using ImplT = c10::detail::IListRefTagImpl; \ + auto& this_ = ImplT::unwrap(*this); \ + BODY \ + } break; + +/* + * Dispatches the unwrap call, depending on `TAG`, followed by + * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`. + * + * This macro is useful because it allows us to handle different + * types (that correspond to different tags) to be implemented + * only once. We can do it even when the implementation of the + * different tags aren't syntatically the same, by dispatching + * it to a function (e.g. `ImplT::(this_)`). + */ +#define TORCH_ILISTREF_UNWRAP(TAG, BODY) \ + switch (TAG) { \ + TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \ + break; \ + default: \ + TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \ + } + +enum class IListRefTag { +#define DEFINE_TAG(tag, ...) tag, + TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG) +#undef DEFINE_TAG + None +}; + +namespace detail { +/* + * Type alias that specifies whether we return a reference or a copy of `T`. + * + * What is this for? + * ================= + * Since values in the boxed world are represented by an `IValue`, we also + * depend on whether it can be converted to a const-reference (`Tensor`) or + * has to create a new copy of `T` (`OptionalTensorRef`). + */ +template +using IListRefConstRef = typename ivalue_to_const_ref_overload_return::type; + +/* + * Interface that implements key functions for each `IListRefTag` type. + * + * What is this for? + * ================= + * Given an `IListRef(Iterator)`, some methods have to be implemented + * differently for each `TAG`. Therefore, the methods inside this class + * are used as dispatch targets for the different `IListRefTag` values. + * + * You should create an specialization of this class for each possible + * combination of `IListRefTag` type (except `None`) and element types + * (e.g. `Tensor`). + * + * What does it do? + * ================ + * 1. defines static methods to be used as dispatch targets by both + * `IListRef` and `IListRefIterator` (see the implementation of + * `IListRefTagImplBase`). + * + * 2. defines the `elem_type` and `list_type` aliases that will be + * used in the definition of `IListRef`. In general, we should do + * so by inheriting from `IListRefTagImplBase`. + * + * [Note: IListRefTagImpl Specialization] + * ====================================== + * For `IListRef(Iterator)`: + * - + * - + * - + * + * For `IListRef(Iterator)`: + * - + * - + * - + */ +template +class IListRefTagImpl {}; + +/* + * Base implementation of `IListRefTagImpl` methods. + * + * What is this for? + * ================= + * This should make adding specializations for new types easier. For + * example, one should be able to add a new type just by making its + * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`. + * + * You should create a partial specialization for this class only if + * you introduce a new `IListRefTag`. 
The idea being that there is one + * default implementation for each possible value of `IListRefTag`. + * + * What does it do? + * ================ + * 1. defines `elem_type` as an alias to `ListElemT`. + * + * 1. defines `list_type` as an alias to the default container type + * that will hold a collection of `elem_type`. The idea being that + * all types tagged as `TAG` will have `list_type` as its container, + * with different `elem_type`. + * + * 3. defines the default implementation for each of the methods that + * are supposed to be defined on `IListRefTagImpl` specializations. + * + * 4. inheriting from `IListRefTagImplBase` also means + * that the payload of the type `IListRef` will be of type `list_type` + * when it is tagged as `TAG`. + */ +template +class IListRefTagImplBase {}; + +/* + * Materialized container for `IListRef`. + * + * What is this for? + * ================= + * Container that groups `T` references together. This exchanges the + * overhead of every method call from `IListRef` for a dynamic allocation. + * + * You should use this container instead of `IListRef` if: + * + * - You are going to iterate the list more than once + * - You need to repeatedly access arbitrary elements (using `operator[]`) + * What does it do? + + * ================ + * Removes the reference (&) from the type, and wraps it into a + * `std::reference_wrapper`. If `IListRefConstRef` is not a + * reference type, then it's left unchanged. + */ +template +using _MaterializedIListRefElem = std::conditional_t< + std::is_reference_v, + typename std::reference_wrapper>, + T>; + +template +using MaterializedIListRefElem = _MaterializedIListRefElem>; + +template +using MaterializedIListRef = std::vector>; + +} // namespace detail + +/* + * Iterator for `IListRef`. + * + * What is it? + * =========== + * Currently, a `std::bidirectional_iterator` that wraps the iterator + * types defined for each of the `IListRefTag`. + * + * One should be able to use it, as if it were the unwrapped + * iterators themselves. + + * What does it do? + * ================ + * Similarly to `IListRef`, this is a wrapper class. Specifically, it + * wraps each container's `const_iterator` type alias. So, for example, + * given that the container for `IListRefTag::Boxed` is `c10::List`, this + * iterator will wrap a `c10::List::const_iterator`. + * + * [Note: MSVC Iterator Debug] + * =========================== + * MSVC `vector::iterator` implementation (used in the boxed variant) + * makes it so this union's destructor, copy-constructor (assignment), and + * move-constructor (assignment) are implicitly deleted. + * + * Therefore, we need to explicitly define them as needed. Follows a list + * of places where these are needed and their reason: + * + * - `Payload` destructor: + * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2. + * + * - `IListRefIterator` destructor: + * same as above. However, we need to explicitly call the variant + * destructor explicitly. + * + * - `IListRefIterator` copy-constructor: + * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different + * than 0. + */ +template +class IListRefIterator { + private: +#define DEFINE_FRIEND_CLASS(TAG, ...) 
\ + friend class detail::IListRefTagImpl; \ + friend class detail::IListRefTagImplBase< \ + IListRefTag::TAG, \ + T, \ + typename detail::IListRefTagImpl::elem_type>; + TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS) +#undef DEFINE_FRIEND_CLASS + + public: + // C++17 friendly std::iterator implementation + using iterator_category = std::bidirectional_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = T&; + + using unboxed_iterator_type = typename detail:: + IListRefTagImpl::list_type::const_iterator; + using boxed_iterator_type = typename detail:: + IListRefTagImpl::list_type::const_iterator; + using materialized_iterator_type = + typename detail::MaterializedIListRef::const_iterator; + + IListRefIterator() : tag_(IListRefTag::None) {} + +#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0 + // See [Note: MSVC Iterator Debug] + IListRefIterator(const IListRefIterator& iterator) + : tag_(iterator.tag_) { + switch (tag_) { + case IListRefTag::Boxed: + payload_.boxed_iterator = iterator.payload_.boxed_iterator; + break; + case IListRefTag::Unboxed: + payload_.unboxed_iterator = iterator.payload_.unboxed_iterator; + break; + case IListRefTag::Materialized: + payload_.materialized_iterator = iterator.payload_.materialized_iterator; + break; + default: + TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); + } + } +#endif + +#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2 + // See [Note: MSVC Iterator Debug] + ~IListRefIterator() noexcept(false) { + switch (tag_) { + case IListRefTag::Boxed: + payload_.boxed_iterator.~boxed_iterator_type(); + break; + case IListRefTag::Unboxed: + payload_.unboxed_iterator.~unboxed_iterator_type(); + break; + case IListRefTag::Materialized: + payload_.materialized_iterator.~materialized_iterator_type(); + break; + default: + TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); + } + } +#endif + + IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) { + payload_.boxed_iterator = boxed; + } + + IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) { + payload_.unboxed_iterator = unboxed; + } + + IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) { + payload_.materialized_iterator = materialized; + } + + detail::IListRefConstRef operator*() const { + TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); }); + } + + IListRefIterator& operator++() { + TORCH_ILISTREF_UNWRAP(tag_, { ++this_; }); + return *this; + } + + IListRefIterator operator++(int) { + auto old = *this; + TORCH_ILISTREF_UNWRAP(tag_, { ++this_; }); + return old; + } + + IListRefIterator& operator--() { + TORCH_ILISTREF_UNWRAP(tag_, { --this_; }); + return *this; + } + + IListRefIterator operator--(int) { + auto old = *this; + TORCH_ILISTREF_UNWRAP(tag_, { --this_; }); + return old; + } + + bool operator==(const IListRefIterator& rhs) const { + if (tag_ != rhs.tag_) { + return false; + } + TORCH_ILISTREF_UNWRAP(tag_, { + auto& rhs_it = ImplT::unwrap(rhs); + return this_ == rhs_it; + }); + } + + bool operator!=(const IListRefIterator& rhs) const { + return !(*this == rhs); + } + + private: + union Payload { + boxed_iterator_type boxed_iterator; + unboxed_iterator_type unboxed_iterator; + materialized_iterator_type materialized_iterator; + void* _init_ptr; + Payload() : _init_ptr(nullptr) {} +#if defined(_MSC_VER) + // See [Note: MSVC Iterator Debug] + ~Payload() {} +#endif + }; + + Payload payload_; + IListRefTag tag_; +}; + +/* + * 
See [Note: IListRef] + */ +template +class IListRef { + private: +#define DEFINE_FRIEND_CLASS(TAG, ...) \ + friend class detail::IListRefTagImpl; \ + friend class detail::IListRefTagImplBase< \ + IListRefTag::TAG, \ + T, \ + typename detail::IListRefTagImpl::elem_type>; + TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS) +#undef DEFINE_FRIEND_CLASS + + public: + using unboxed_type = + typename detail::IListRefTagImpl::list_type; + using boxed_type = + typename detail::IListRefTagImpl::list_type; + using materialized_type = + typename detail::MaterializedIListRef; + + using iterator = IListRefIterator; + using const_iterator = IListRefIterator; + using reverse_iterator = std::reverse_iterator; + using value_type = typename iterator::value_type; + + IListRef() : tag_(IListRefTag::None) {} + + IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) { + payload_.boxed = &boxed; + } + + IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) { + payload_.unboxed = unboxed; + } + + IListRef(const std::initializer_list& list) : tag_(IListRefTag::Unboxed) { + payload_.unboxed = at::ArrayRef(list); + } + + template < + typename... UnboxedConstructorArgs, + typename = std::enable_if_t< + std::is_constructible_v>> + IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) { + payload_.unboxed = unboxed_type(std::forward(args)...); + } + + IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) { + payload_.materialized = &materialized; + } + + size_t size() const { + TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); }); + } + + bool empty() const { + return size() == 0; + } + + iterator begin() const { + TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); }); + } + + iterator end() const { + TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); }); + } + + detail::IListRefConstRef front() const { + TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); }); + } + + /* + * Materializes the `IListRef` into a `std::vector`. + * + * This should be used when one wishes to either: + * + * - iterate over the list more than once: each `IListRefIterator` + * member function call has to go through a switch, introducing + * non-negligible overhead + * + * - randomly access an arbitrary element using `operator[]`: + * same reason as above + */ + detail::MaterializedIListRef materialize() const { + if (isMaterialized()) { + return toMaterialized(); + } + + detail::MaterializedIListRef materialized; + materialized.reserve(size()); + for (const auto& t : *this) { + materialized.emplace_back(t); + } + return materialized; + } + +#define DEFINE_CHECK(TAG, ...) \ + bool is##TAG() const { \ + return tag_ == IListRefTag::TAG; \ + } + TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK) +#undef DEFINE_CHECK + + bool isNone() const { + return tag_ == IListRefTag::None; + } + +#define DEFINE_CASTING(TAG, ...) 
\ + const typename detail::IListRefTagImpl::list_type& \ + to##TAG() const { \ + TORCH_INTERNAL_ASSERT(is##TAG()); \ + return detail::IListRefTagImpl::unwrap(*this); \ + } + TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING) +#undef DEFINE_CASTING + + private: + union Payload { + const boxed_type* boxed; + unboxed_type unboxed; + const materialized_type* materialized; + Payload() : boxed(nullptr) {} + }; + + Payload payload_; + IListRefTag tag_; +}; + +} // namespace c10 + +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..7faec3c11669e61ca45a62d063941f939dd5a12b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h @@ -0,0 +1,203 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; +class OptionalTensorRef; +} + + +namespace c10::detail { + +/* + * Specializations of `IListRefTagImplBase` that implement the default + * implementation for `IListRefTag::Unboxed`. + */ +template +class IListRefTagImplBase { + public: + using elem_type = ListElemT; + using list_type = ArrayRef; + + /* + * These `unwrap` static methods unwraps the inner containers out + * of `IListRef` (and `IListRefIterator`). They are required when + * the macro `TORCH_ILISTREF_UNWRAP` is called. + */ + static const list_type& unwrap(const IListRef& ilist) { + return ilist.payload_.unboxed; + } + + static typename list_type::const_iterator& unwrap(IListRefIterator& it) { + return it.payload_.unboxed_iterator; + } + + static const typename list_type::const_iterator& unwrap( + const IListRefIterator& it) { + return it.payload_.unboxed_iterator; + } + + /* + * We have these function (besides the `unwrap`s above) because the + * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*` + * weren't syntatically equal for the existing tags at the time + * (`Unboxed` and `Boxed`). + */ + static IListRefConstRef front(const list_type& lst) { + return lst.front(); + } + + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + return *it; + } +}; + +/* + * Specializations of `IListRefTagImplBase` that implement the default + * implementation for `IListRefTag::Boxed`. + */ +template +class IListRefTagImplBase { + public: + using elem_type = ListElemT; + using list_type = List; + + static const list_type& unwrap(const IListRef& ilist) { + return *ilist.payload_.boxed; + } + + static typename list_type::const_iterator& unwrap(IListRefIterator& it) { + return it.payload_.boxed_iterator; + } + + static const typename list_type::const_iterator& unwrap( + const IListRefIterator& it) { + return it.payload_.boxed_iterator; + } + + static IListRefConstRef front(const list_type& lst) { + return lst[0]; + } + + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + return (*it).get().toTensor(); + } +}; + +/* + * Specializations of `IListRefTagImplBase` that implement the default + * implementation for `IListRefTag::Materialized`. 
+ */ +template +class IListRefTagImplBase> { + public: + using elem_type = MaterializedIListRefElem; + using list_type = MaterializedIListRef; + + static const list_type& unwrap(const IListRef& ilist) { + return *ilist.payload_.materialized; + } + + static typename list_type::const_iterator& unwrap(IListRefIterator& it) { + return it.payload_.materialized_iterator; + } + + static const typename list_type::const_iterator& unwrap( + const IListRefIterator& it) { + return it.payload_.materialized_iterator; + } + + static IListRefConstRef front(const list_type& lst) { + return lst[0]; + } + + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + return *it; + } +}; + +/* + * [Note: ITensorListRef] + * Specializations necessary for `IListRef` type. + * + * Since the default implementations are usually done with supporting + * `Tensor` in mind, we only have to inherit from the base implementations. + */ +template <> +class IListRefTagImpl + : public IListRefTagImplBase {}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase {}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase< + IListRefTag::Materialized, + at::Tensor, + MaterializedIListRefElem> {}; + +/* + * [Note: IOptTensorListRef] + * Specializations necessary for `IListRef` type. + * + * We can't get an `at::OptionalTensorRef` directly from an instance of + * `List>` (the type that corresponds to the boxed world). + * + * So, the default implementation won't help us. Thus, we have to implement + * this method ourselves. + */ +template <> +class IListRefTagImpl + : public IListRefTagImplBase {}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase> { + + public: + /* + * Given an instance of the types corresponding to the `Boxed` tag, we override + * the default implementation, so that we can return a `at::OptionalTensorRef`. + */ + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdangling-reference") + const auto& ivalue = (*it).get(); + C10_DIAGNOSTIC_POP() + if (!ivalue.isNone()) { + const auto& tensor = ivalue.toTensor(); + return (tensor.defined()) ? tensor : at::OptionalTensorRef{}; + } + return {}; + } +}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase< + IListRefTag::Materialized, + at::OptionalTensorRef, + MaterializedIListRefElem> {}; + +} // namespace c10::detail + + +namespace at { + +// [Note: ITensorListRef] +using ITensorListRef = c10::IListRef; +using ITensorListRefIterator = c10::IListRefIterator; +using MaterializedITensorListRef = c10::detail::MaterializedIListRef; +// [Note: IOptTensorListRef] +using IOptTensorListRef = c10::IListRef; +using IOptTensorListRefIterator = c10::IListRefIterator; +using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1244d0fda87ffd3846c6c6352ed46a7de45b5678 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h @@ -0,0 +1,111 @@ +#pragma once + +// The legacy mechanism for dispatching operators in ATen is a Type +// object, which is essentially a giant virtual dispatch table +// for every operation we support dynamically dispatching over. 
+// +// This has been deprecated in favor of ATenDispatch, and in the future, +// c10 dispatcher. +// TODO: Clean up what remains here + +#include + +namespace at { + +// A RAII, thread local (!) guard that will disable dispatch to variable +// handler. +// +// NOTE [ Treating Variables as non-Variables in type dispatch ] +// +// What exactly does AutoDispatchBelowAutograd do? The short answer is, it causes +// dispatches on ATen functions to go to the non-variable implementation, +// bypassing autograd handling (and also profiling and tracing). +// +// To understand why this guard exists, it's helpful to understand the history +// behind how Variable was implemented. Previously, Variables were implemented +// as a wrapper on Tensors; so the act of processing a Variable involved +// unwrapping the underlying Tensor, and then calling the underlying base +// operation on /that/ operation +// +// However, after the Variable/Tensor merge, there is no concept of unwrapping +// a tensor anymore. If you just call the operation on the same variable +// again inside your VariableType handler, you'll dispatch back to +// VariableType, which is not what we want. +// +// The solution to the above problem is to add `at::AutoDispatchBelowAutograd`, which +// when enabled will cause `legacyTensorType()` and `getType()` to always return +// non-Variable type, even if the tensor being called on is a variable. + +/* Note [AutoDispatchBelowAutograd] + * AutoDispatchBelowAutograd is **INTERNAL ONLY** that it should be used + * for kernel implementations and customized C++ kernels. + * If you are looking for a guard to run workload in inference mode, please use + * c10::InferenceMode RAII which is user facing API. + * In the past AutoDispatchBelowAutograd(or its old version AutoNonVariableTypeMode) + * was used in the user code for inference-only workload, this was under risk of + * producing wrong results silently in some edge cases. For example: + * ``` + * torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true); + * torch::Tensor out = s * s; + * { + * at::AutoDispatchBelowAutograd guard; + * s.add_(1); // Skips version bump on `s`. + * } + * // WRONG GRADIENT! s.grad() are now computed using `s` value after the + * // inplace update. + * out.backward(torch::ones_like(out)); + * ``` + * Users should use `c10::InferenceMode` here so that it'll properly throw an + * error saying "one of the variables needed for gradient computation has be modified." + */ +struct TORCH_API AutoDispatchBelowAutograd { + AutoDispatchBelowAutograd() : + autograd_guard_(c10::autograd_dispatch_keyset) { + } + + // disable all autograd dispatch keys + c10::impl::ExcludeDispatchKeyGuard autograd_guard_; +}; + +// TODO: AutoNonVariableTypeMode should be removed in release 1.10. +struct TORCH_API AutoNonVariableTypeMode { + AutoNonVariableTypeMode(bool enabled = true) : + autograd_guard_(c10::autograd_dispatch_keyset) { + TORCH_WARN_ONCE("AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. " + "For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, " + "If you are looking for a user facing API to enable running your inference-only " + "workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code " + "is under risk of producing silent wrong result in some edge cases. 
" + "See Note [AutoDispatchBelowAutograd] for more details."); + TORCH_INTERNAL_ASSERT(enabled); + } + + // disable all autograd dispatch keys + c10::impl::ExcludeDispatchKeyGuard autograd_guard_; +}; + +struct TORCH_API AutoDispatchSkipFunctionalize { + AutoDispatchSkipFunctionalize() : + dispatch_key_guard_(c10::DispatchKeySet(c10::DispatchKey::Functionalize)) { + } + c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_; +}; + +/* Note [AutoDispatchBelowADInplaceOrView] + * AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode + * before we split inplace & view ops out of VariableType kernel. + * Note this guard is used in VariableType kernels for functional ops + * as well as ADInplaceOrView kernels for inplace/view ops to enforce the + * Invariant: + * Once you are in VariableType/ADInplaceOrView kernel for an op, + * you never go back to a kernel on same dispatch key until + * you finish the current op. + */ +struct TORCH_API AutoDispatchBelowADInplaceOrView { + AutoDispatchBelowADInplaceOrView() : + dispatch_key_guard_(c10::autograd_dispatch_keyset_with_ADInplaceOrView) { + } + // disable Autograd & ADInplaceOrView dispatch keys + c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_; +}; +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/List.h b/phivenv/Lib/site-packages/torch/include/ATen/core/List.h new file mode 100644 index 0000000000000000000000000000000000000000..3635219c35baff5cf486e38d19443c9d6a402466 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/List.h @@ -0,0 +1,491 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +class Tensor; +} +namespace c10 { +struct IValue; +template class List; +struct Type; + +namespace detail { + +struct ListImpl final : public c10::intrusive_ptr_target { + using list_type = std::vector; + + explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_); + + list_type list; + + TypePtr elementType; + + intrusive_ptr copy() const { + return make_intrusive(list, elementType); + } + friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs); +}; +} + +namespace impl { + +template class ListIterator; + +template class ListElementReference; + +template +void swap(ListElementReference&& lhs, ListElementReference&& rhs) noexcept; + +template +bool operator==(const ListElementReference& lhs, const T& rhs); + +template +bool operator==(const T& lhs, const ListElementReference& rhs); + +template +struct ListElementConstReferenceTraits { + // In the general case, we use IValue::to(). + using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return::type; +}; + +// There is no to() overload for std::optional. 
+template<> +struct ListElementConstReferenceTraits> { + using const_reference = std::optional>; +}; + +template +class ListElementReference final { +public: + operator std::conditional_t< + std::is_reference_v::type>, + const T&, + T>() const; + + ListElementReference& operator=(T&& new_value) &&; + + ListElementReference& operator=(const T& new_value) &&; + + // assigning another ref to this assigns the underlying value + ListElementReference& operator=(ListElementReference&& rhs) && noexcept; + + const IValue& get() const& { + return *iterator_; + } + + friend void swap(ListElementReference&& lhs, ListElementReference&& rhs) noexcept; + + ListElementReference(const ListElementReference&) = delete; + ListElementReference& operator=(const ListElementReference&) = delete; + ~ListElementReference() = default; + +private: + ListElementReference(Iterator iter) + : iterator_(iter) {} + + // allow moving, but only our friends (i.e. the List class) can move us + ListElementReference(ListElementReference&&) noexcept = default; + ListElementReference& operator=(ListElementReference&& rhs) & noexcept { + iterator_ = std::move(rhs.iterator_); + return *this; + } + + friend class List; + friend class ListIterator; + + Iterator iterator_; +}; + +// this wraps vector::iterator to make sure user code can't rely +// on it being the type of the underlying vector. +template +class ListIterator final { + public: + // C++17 friendly std::iterator implementation + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = ListElementReference; + + explicit ListIterator() = default; + ~ListIterator() = default; + + ListIterator(const ListIterator&) = default; + ListIterator(ListIterator&&) noexcept = default; + ListIterator& operator=(const ListIterator&) = default; + ListIterator& operator=(ListIterator&&) noexcept = default; + + ListIterator& operator++() { + ++iterator_; + return *this; + } + + ListIterator operator++(int) { + ListIterator copy(*this); + ++*this; + return copy; + } + + ListIterator& operator--() { + --iterator_; + return *this; + } + + ListIterator operator--(int) { + ListIterator copy(*this); + --*this; + return copy; + } + + ListIterator& operator+=(typename List::size_type offset) { + iterator_ += offset; + return *this; + } + + ListIterator& operator-=(typename List::size_type offset) { + iterator_ -= offset; + return *this; + } + + ListIterator operator+(typename List::size_type offset) const { + return ListIterator{iterator_ + offset}; + } + + ListIterator operator-(typename List::size_type offset) const { + return ListIterator{iterator_ - offset}; + } + + friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ - rhs.iterator_; + } + + ListElementReference operator*() const { + return {iterator_}; + } + + ListElementReference operator[](typename List::size_type offset) const { + return {iterator_ + offset}; + } + +private: + explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {} + + Iterator iterator_; + + friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ == rhs.iterator_; + } + + friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) { + return !(lhs == rhs); + } + + friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ < rhs.iterator_; + } + + friend bool operator<=(const ListIterator& lhs, const 
ListIterator& rhs) { + return lhs.iterator_ <= rhs.iterator_; + } + + friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ > rhs.iterator_; + } + + friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ >= rhs.iterator_; + } + + friend class ListIterator; + friend class List; +}; + +template List toTypedList(List list); +template List toList(List&& list); +template List toList(const List& list); +const IValue* ptr_to_first_element(const List& list); +} + +/** + * An object of this class stores a list of values of type T. + * + * This is a pointer type. After a copy, both Lists + * will share the same storage: + * + * > List a; + * > List b = a; + * > b.push_back("three"); + * > ASSERT("three" == a.get(0)); + * + * We use this class in the PyTorch kernel API instead of + * std::vector, because that allows us to do optimizations + * and switch out the underlying list implementation without + * breaking backwards compatibility for the kernel API. + */ +template +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class List final { +private: + // This is an intrusive_ptr because List is a pointer type. + // Invariant: This will never be a nullptr, there will always be a valid + // ListImpl. + c10::intrusive_ptr impl_; + + using internal_reference_type = impl::ListElementReference; + using internal_const_reference_type = typename impl::ListElementConstReferenceTraits::const_reference; + +public: + using value_type = T; + using size_type = typename c10::detail::ListImpl::list_type::size_type; + using iterator = impl::ListIterator; + using const_iterator = impl::ListIterator; + using reverse_iterator = impl::ListIterator; + + /** + * Constructs an empty list. + */ + explicit List(); + + /** + * Constructs a list with some initial values. + * Example: + * List a({2, 3, 4}); + */ + List(std::initializer_list initial_values); + explicit List(ArrayRef initial_values); + + /** + * Create a generic list with runtime type information. + * This only works for c10::impl::GenericList and is not part of the public API + * but only supposed to be used internally by PyTorch. + */ + explicit List(TypePtr elementType); + + List(const List&) = default; + List& operator=(const List&) = default; + ~List() = default; + + /** + * Create a new List pointing to a deep copy of the same data. + * The List returned is a new list with separate storage. + * Changes in it are not reflected in the original list or vice versa. + */ + List copy() const; + + /** + * Returns the element at specified location pos, with bounds checking. + * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. + */ + internal_const_reference_type get(size_type pos) const; + + /** + * Moves out the element at the specified location pos and returns it, with bounds checking. + * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. + * The list contains an invalid element at position pos afterwards. Any operations + * on it before re-setting it are invalid. + */ + value_type extract(size_type pos) const; + + /** + * Returns a reference to the element at specified location pos, with bounds checking. + * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. 
+ * + * You cannot store the reference, but you can read it and assign new values to it: + * + * List list = ...; + * list[2] = 5; + * int64_t v = list[1]; + */ + internal_const_reference_type operator[](size_type pos) const; + + internal_reference_type operator[](size_type pos); + + /** + * Assigns a new value to the element at location pos. + */ + void set(size_type pos, const value_type& value) const; + + /** + * Assigns a new value to the element at location pos. + */ + void set(size_type pos, value_type&& value) const; + + /** + * Returns an iterator to the first element of the container. + * If the container is empty, the returned iterator will be equal to end(). + */ + iterator begin() const; + + /** + * Returns an iterator to the element following the last element of the container. + * This element acts as a placeholder; attempting to access it results in undefined behavior. + */ + iterator end() const; + + /** + * Checks if the container has no elements. + */ + bool empty() const; + + /** + * Returns the number of elements in the container + */ + size_type size() const; + + /** + * Increase the capacity of the vector to a value that's greater or equal to new_cap. + */ + void reserve(size_type new_cap) const; + + /** + * Erases all elements from the container. After this call, size() returns zero. + * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated. + */ + void clear() const; + + /** + * Inserts value before pos. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator insert(iterator pos, const T& value) const; + + /** + * Inserts value before pos. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator insert(iterator pos, T&& value) const; + + /** + * Inserts a new element into the container directly before pos. + * The new element is constructed with the given arguments. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + template + iterator emplace(iterator pos, Args&&... value) const; + + /** + * Appends the given element value to the end of the container. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void push_back(const T& value) const; + + /** + * Appends the given element value to the end of the container. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void push_back(T&& value) const; + + /** + * Appends the given list to the end of the container. Uses at most one memory allocation. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void append(List lst) const; + + /** + * Appends the given element value to the end of the container. + * The new element is constructed with the given arguments. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + template + void emplace_back(Args&&... args) const; + + /** + * Removes the element at pos. 
+ * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator erase(iterator pos) const; + + /** + * Removes the elements in the range [first, last). + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator erase(iterator first, iterator last) const; + + /** + * Removes the last element of the container. + * Calling pop_back on an empty container is undefined. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void pop_back() const; + + /** + * Resizes the container to contain count elements. + * If the current size is less than count, additional default-inserted elements are appended. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void resize(size_type count) const; + + /** + * Resizes the container to contain count elements. + * If the current size is less than count, additional copies of value are appended. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void resize(size_type count, const T& value) const; + + /** + * Value equality comparison. This function implements Python-like semantics for + * equality: two lists with the same identity (e.g. same pointer) trivially + * compare equal, otherwise each element is compared for equality. + */ + template + friend bool operator==(const List& lhs, const List& rhs); + + template + friend bool operator!=(const List& lhs, const List& rhs); + + /** + * Identity comparison. Returns true if and only if `rhs` represents the same + * List object as `this`. + */ + bool is(const List& rhs) const; + + std::vector vec() const; + + /** + * Returns the number of Lists currently pointing to this same list. + * If this is the only instance pointing to this list, returns 1. + */ + // TODO Test use_count + size_t use_count() const; + + TypePtr elementType() const; + + // See [unsafe set type] for why this exists. + void unsafeSetElementType(TypePtr t); + +private: + explicit List(c10::intrusive_ptr&& elements); + explicit List(const c10::intrusive_ptr& elements); + friend struct IValue; + template friend List impl::toTypedList(List); + template friend List impl::toList(List&&); + template friend List impl::toList(const List&); + friend const IValue* impl::ptr_to_first_element(const List& list); +}; + +namespace impl { +// GenericList is how IValue stores lists. It is, however, not part of the +// public API. Kernels should use Lists with concrete types instead +// (maybe except for some internal prim ops). 
+using GenericList = List; + +} +} + +namespace torch { + template using List = c10::List; +} + +#include // IWYU pragma: keep diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/List_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/List_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..c60ca363ee58fc331e66d7394d0d49b795c43dae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/List_inl.h @@ -0,0 +1,353 @@ +#pragma once + +#include +#include + +namespace c10 { + +template decltype(auto) getTypePtr(); +std::string toString(const Type& type); + +template +List::List(c10::intrusive_ptr&& elements) +: impl_(std::move(elements)) {} + +template +List::List(const c10::intrusive_ptr& elements) +: impl_(elements) {} + +template +List::List() +: List(make_intrusive( + typename c10::detail::ListImpl::list_type(), + getTypePtr())) { + static_assert(!std::is_same_v, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead."); +} + +template +List::List(ArrayRef values) +: List(make_intrusive( + typename c10::detail::ListImpl::list_type(), + getTypePtr())) { + static_assert(!std::is_same_v, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType)."); + impl_->list.reserve(values.size()); + for (const T& element : values) { + impl_->list.push_back(element); + } +} + +template +List::List(std::initializer_list initial_values) +: List(ArrayRef(initial_values)) { + static_assert(!std::is_same_v, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType)."); +} + +template +List::List(TypePtr elementType) +: List(make_intrusive( + typename c10::detail::ListImpl::list_type(), + std::move(elementType))) { + static_assert(std::is_same_v || std::is_same_v>, + "This constructor is only valid for c10::impl::GenericList or List."); +} + +namespace impl { +template +List toTypedList(impl::GenericList list) { + // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant + // because upcasting would allow people to add types into the new list that would break the old list. + // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can + // allow upcasting. This can be a perf improvement since we can cast List to List> + // without having to copy it. This is also used to provide backwards compatibility with some old models + // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_ + // as List before we changed that argument to be List>. When deserializing, we + // have list.use_count() == 1 and can deserialize the List directly as List>. + TORCH_CHECK(*list.impl_->elementType == *getTypePtr() + || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr())) + , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr()), ">. 
Types mismatch."); + return List(std::move(list.impl_)); +} + +template +impl::GenericList toList(List&& list) { + return GenericList(std::move(list.impl_)); +} +template +impl::GenericList toList(const List& list) { + return GenericList(list.impl_); +} +} + +template +List List::copy() const { + return List(impl_->copy()); +} + +namespace detail { + template + T list_element_to(T element) { + return element; + } + template + T list_element_to(const IValue& element) { + return element.template to(); + } + template + T list_element_to(IValue&& element) { + return std::move(element).template to(); + } + template + struct ListElementFrom { + static IValue from(const T& element) { + return element; + } + static IValue from(T&& element) { + return std::move(element); + } + }; + template<> + struct ListElementFrom { + static const IValue& from(const IValue& element) { + return element; + } + static IValue&& from(IValue&& element) { + return std::move(element); + } + }; +} + +namespace impl { + +template +ListElementReference::operator std::conditional_t< + std::is_reference_v::type>, + const T&, + T>() const { + return iterator_->template to(); +} + +template +ListElementReference& ListElementReference::operator=(T&& new_value) && { + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); + return *this; +} + +template +ListElementReference& ListElementReference::operator=(const T& new_value) && { + *iterator_ = c10::detail::ListElementFrom::from(new_value); + return *this; +} + +template +ListElementReference& ListElementReference::operator=(ListElementReference&& rhs) && noexcept { + *iterator_ = *rhs.iterator_; + return *this; +} + +template +void swap(ListElementReference&& lhs, ListElementReference&& rhs) noexcept { + std::swap(*lhs.iterator_, *rhs.iterator_); +} + +template +bool operator==(const ListElementReference& lhs, const T& rhs) { + const T& lhs_tmp = lhs; + return lhs_tmp == rhs; +} + +template +inline bool operator==(const T& lhs, const ListElementReference& rhs) { + return rhs == lhs; +} + +template +inline typename ListElementConstReferenceTraits::const_reference +list_element_to_const_ref(const IValue& element) { + return element.template to(); +} + +template<> +inline typename ListElementConstReferenceTraits>::const_reference +list_element_to_const_ref>(const IValue& element) { + return element.toOptionalStringRef(); +} + +} // namespace impl + +template +void List::set(size_type pos, const value_type& value) const { + impl_->list.at(pos) = c10::detail::ListElementFrom::from(value); +} + +template +void List::set(size_type pos, value_type&& value) const { + impl_->list.at(pos) = c10::detail::ListElementFrom::from(std::move(value)); +} + +template +typename List::internal_const_reference_type List::get(size_type pos) const { + return operator[](pos); +} + +template +typename List::internal_const_reference_type List::operator[](size_type pos) const { + return c10::impl::list_element_to_const_ref(impl_->list.at(pos)); +} + +template +typename List::internal_reference_type List::operator[](size_type pos) { + static_cast(impl_->list.at(pos)); // Throw the exception if it is out of range. 
+ return {impl_->list.begin() + static_castlist)::difference_type>(pos)}; +} + +template +typename List::value_type List::extract(size_type pos) const { + auto& elem = impl_->list.at(pos); + auto result = c10::detail::list_element_to(std::move(elem)); + // Reset the list element to a T() instead of None to keep it correctly typed + elem = c10::detail::ListElementFrom::from(T{}); + return result; +} + +template +typename List::iterator List::begin() const { + return iterator(impl_->list.begin()); +} + +template +typename List::iterator List::end() const { + return iterator(impl_->list.end()); +} + +template +bool List::empty() const { + return impl_->list.empty(); +} + +template +typename List::size_type List::size() const { + return impl_->list.size(); +} + +template +void List::reserve(size_type new_cap) const { + impl_->list.reserve(new_cap); +} + +template +void List::clear() const { + impl_->list.clear(); +} + +template +typename List::iterator List::insert(iterator pos, const T& value) const { + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(value)) }; +} + +template +typename List::iterator List::insert(iterator pos, T&& value) const { + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(std::move(value))) }; +} + +template +template +typename List::iterator List::emplace(iterator pos, Args&&... value) const { + // TODO Use list_element_from? + return iterator { impl_->list.emplace(pos.iterator_, std::forward(value)...) }; +} + +template +void List::push_back(const T& value) const { + impl_->list.push_back(c10::detail::ListElementFrom::from(value)); +} + +template +void List::push_back(T&& value) const { + impl_->list.push_back(c10::detail::ListElementFrom::from(std::move(value))); +} + +template +void List::append(List b) const { + if (b.use_count() == 1) { + impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end())); + } else { + impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end()); + } +} + +template +template +void List::emplace_back(Args&&... args) const { + // TODO Use list_element_from? + impl_->list.push_back(T(std::forward(args)...)); +} + +template +typename List::iterator List::erase(iterator pos) const { + return iterator { impl_->list.erase(pos.iterator_) }; +} + +template +typename List::iterator List::erase(iterator first, iterator last) const { + return iterator { impl_->list.erase(first.iterator_, last.iterator_) }; +} + +template +void List::pop_back() const { + impl_->list.pop_back(); +} + +template +void List::resize(size_type count) const { + impl_->list.resize(count, T{}); +} + +template +void List::resize(size_type count, const T& value) const { + impl_->list.resize(count, value); +} + +template +bool operator==(const List& lhs, const List& rhs) { + // Lists with the same identity trivially compare equal. + if (lhs.impl_ == rhs.impl_) { + return true; + } + + // Otherwise, just compare values directly. 
+ return *lhs.impl_ == *rhs.impl_; +} + +template +bool operator!=(const List& lhs, const List& rhs) { + return !(lhs == rhs); +} + +template +bool List::is(const List& rhs) const { + return this->impl_ == rhs.impl_; +} + +template +std::vector List::vec() const { + std::vector result(begin(), end()); + return result; +} + +template +size_t List::use_count() const { + return impl_.use_count(); +} + +template +TypePtr List::elementType() const { + return impl_->elementType; +} + +template +void List::unsafeSetElementType(TypePtr t) { + impl_->elementType = std::move(t); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h b/phivenv/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..7aaebf8289e5c3ce80411846ea8f24ca29c3f620 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h @@ -0,0 +1,194 @@ +#pragma once + +#include + +// define constants like M_PI and C keywords for MSVC +#ifdef _MSC_VER +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif +#include +#endif + +#include +#include +#include + +namespace at { + +constexpr int MERSENNE_STATE_N = 624; +constexpr int MERSENNE_STATE_M = 397; +constexpr uint32_t MATRIX_A = 0x9908b0df; +constexpr uint32_t UMASK = 0x80000000; +constexpr uint32_t LMASK = 0x7fffffff; + +/** + * Note [Mt19937 Engine implementation] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Originally implemented in: + * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c + * and modified with C++ constructs. Moreover the state array of the engine + * has been modified to hold 32 bit uints instead of 64 bits. + * + * Note that we reimplemented mt19937 instead of using std::mt19937 because, + * at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2 + * by default and following are the benchmark numbers (benchmark code can be found at + * https://github.com/syed-ahmed/benchmark-rngs): + * + * with -O2 + * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s + * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s + * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s + * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s + * + * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution, + * however we can't use std::uniform_real_distribution because of this bug: + * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used + * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm + * than what's in pytorch currently and that messes up the tests in tests_distributions.py. + * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower + * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter. + * + * Copyright notice: + * A C-program for MT19937, with initialization improved 2002/2/10. + * Coded by Takuji Nishimura and Makoto Matsumoto. + * This is a faster version by taking Shawn Cokus's optimization, + * Matthe Bellew's simplification, Isaku Wada's real version. + * + * Before using, initialize the state by using init_genrand(seed) + * or init_by_array(init_key, key_length). + * + * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. The names of its contributors may not be used to endorse or promote + * products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * + * Any feedback is very welcome. + * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html + * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space) + */ + +/** + * mt19937_data_pod is used to get POD data in and out + * of mt19937_engine. Used in torch.get_rng_state and + * torch.set_rng_state functions. + */ +struct mt19937_data_pod { + uint64_t seed_; + int left_; + bool seeded_; + uint32_t next_; + std::array state_; +}; + +class mt19937_engine { +public: + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + inline explicit mt19937_engine(uint64_t seed = 5489) { + init_with_uint32(seed); + } + + inline mt19937_data_pod data() const { + return data_; + } + + inline void set_data(const mt19937_data_pod& data) { + data_ = data; + } + + inline uint64_t seed() const { + return data_.seed_; + } + + inline bool is_valid() { + if ((data_.seeded_ == true) + && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N) + && (data_.next_ <= MERSENNE_STATE_N)) { + return true; + } + return false; + } + + inline uint32_t operator()() { + if (--(data_.left_) == 0) { + next_state(); + } + uint32_t y = *(data_.state_.data() + data_.next_++); + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680; + y ^= (y << 15) & 0xefc60000; + y ^= (y >> 18); + + return y; + } + +private: + mt19937_data_pod data_; + + inline void init_with_uint32(uint64_t seed) { + data_.seed_ = seed; + data_.seeded_ = true; + data_.state_[0] = seed & 0xffffffff; + for (const auto j : c10::irange(1, MERSENNE_STATE_N)) { + data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j); + } + data_.left_ = 1; + data_.next_ = 0; + } + + inline uint32_t mix_bits(uint32_t u, uint32_t v) { + return (u & UMASK) | (v & LMASK); + } + + inline uint32_t twist(uint32_t u, uint32_t v) { + return (mix_bits(u,v) >> 1) ^ (v & 1 ? 
MATRIX_A : 0); + } + + inline void next_state() { + uint32_t* p = data_.state_.data(); + data_.left_ = MERSENNE_STATE_N; + data_.next_ = 0; + + for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) { + *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]); + } + + for(int j = MERSENNE_STATE_M; --j; p++) { + *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]); + } + + *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]); + } + +}; + +typedef mt19937_engine mt19937; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/NamedTensor.h b/phivenv/Lib/site-packages/torch/include/ATen/core/NamedTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..fd63651d48770b290ca6ae05fcdb45f8cb13bf91 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/NamedTensor.h @@ -0,0 +1,143 @@ +#pragma once + +#include +#include + +namespace at { + +class TensorBase; + +// XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen. +// Due to the c10/ATen library split, TensorImpl cannot depend on Dimname, +// so we have a couple of workarounds. +// +// In the long term, we'll move Dimname to c10 and everything in this file +// can be refactored out. The main blocker for that is that "c10::Symbol" +// actually exists outside of c10 and needs to be moved in. + +// TensorImpl has a unique_ptr field. +// XXX: Ideally we would just put std::optional> into TensorImpl. +// +// This class has an important invariant: there must be at least ONE +// non-wildcard +struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface { + // This enum is to remind people that the invariant on constructors is that + // the list of dimnames must have at least one non-wildcard + enum HAS_NON_WILDCARD { + HasNonWildcard + }; + + explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names) + : names_(names.vec()) { + check_invariants(); + } + explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector&& names) + : names_(std::move(names)) { + check_invariants(); + } + + std::unique_ptr clone() const override { + return std::make_unique(HasNonWildcard, names_); + } + + DimnameList names() const { return names_; } + + // Used for an assertion in TensorImpl.h + int64_t slow_dim() const override { + return static_cast(names_.size()); + } + + void check_invariants() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); })); + } + + void set_names(HAS_NON_WILDCARD, DimnameList new_names) { + TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); + std::copy(new_names.begin(), new_names.end(), names_.begin()); + check_invariants(); + } + + void set_names(HAS_NON_WILDCARD, std::vector&& new_names) { + TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); + names_ = std::move(new_names); + check_invariants(); + } + + // INVARIANT: at least one Dimname is non-WILDCARD + std::vector names_; +}; + +// When NamesMode is disabled, then all operations ignore tensors' names fields. +// Concretely speaking, all tensors are treated as having nullopt names. +struct TORCH_API NamesMode { + static bool is_enabled(); + static void set_enabled(bool enabled); +}; + + +// A RAII, thread local (!) guard that enables or disables names upon +// construction, and sets it back to the original value upon destruction. 
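+//
+// A hedged usage sketch (illustrative):
+//
+//   {
+//     at::NoNamesGuard guard;   // dimension names are ignored inside this scope
+//     // ... run ops that should not see or propagate names ...
+//   }                           // the previous NamesMode value is restored here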
+struct TORCH_API NoNamesGuard { + NoNamesGuard() : prev_mode(NamesMode::is_enabled()) { + NamesMode::set_enabled(false); + } + NoNamesGuard(const NoNamesGuard&) = delete; + NoNamesGuard(NoNamesGuard&&) = delete; + NoNamesGuard& operator=(const NoNamesGuard&) = delete; + NoNamesGuard& operator=(NoNamesGuard&&) = delete; + ~NoNamesGuard() { + if (initialized) { + reset(); + } + } + void reset() { + TORCH_INTERNAL_ASSERT(initialized); + NamesMode::set_enabled(prev_mode); + } + private: + bool prev_mode; + bool initialized{true}; +}; + +void check_names_valid_for(const TensorBase& tensor, DimnameList names); +void check_names_valid_for(size_t tensor_dim, DimnameList names); + +// Sets the names of `tensor` to be `names`. +TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional names); +TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector&& names, bool validate_names); + +constexpr size_t kMaxNamedTensorDim = 64; + +DimnameList default_names(size_t len); + +namespace impl { + +// Some helper functions on TensorImpl. Useful for working with names in TH. +// XXX: Ideally these would exist as methods on TensorImpl +TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional names, bool validate_names); +TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); + +void check_names_valid_for(TensorImpl* impl, DimnameList names); + +// Returns true if the tensor's names exist and are not all 'None'. +// Returns false if the tensor's names don't exist (were not allocated), +// or if all names are 'None'. +// We treat not-allocated-names the same as allocated names that are all 'None'. +TORCH_API bool has_names(const TensorImpl* impl); + +// Returns the names of the tensor's dimensions. +// Unnamed tensors are treated as having 'None' in all dimension; this method +// would return a DimnameList of all 'None's for an unnamed tensor. +TORCH_API DimnameList get_names(const TensorImpl* impl); + +// This is more of an implementation detail; one should use impl::get_names / +// Tensor::names() whenever possible because it provides a cleaner API. +// Returns the names of the tensor if they have been allocated; returns nullopt +// instead if the haven't been. The names of a tensor are not allocated if a +// tensor is constructed with names=None. +TORCH_API std::optional get_opt_names(const TensorImpl* impl); + +} // namespace impl + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..cc67bec1512041bd3a5117d63b8f8470dd298be1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h @@ -0,0 +1,187 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// The motivating usecase for this is to represent the ragged size structure +// of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This +// allows us to simply return [B, j0, D] if someone queries for the size of our +// tensor. +// +// Morally we define comparison between two nested ints to return true if +// that comparison holds for all corresponding elements of the arrays they +// represent. Comparison between a nested int and a plain int is defined +// similarly. 
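+// For example, if j0 stands for the ragged sizes [3, 5, 4], then "j0 >= 2"
+// holds for every element and is true, while "j0 == 0" holds for none of them
+// and is false.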
+// +// To simulate this desired behavior but also avoid the O(N) cost of checking, +// we associate each raggedness pattern with an integer "id" that can be used as +// a proxy to evaluate equality. We also constrain the range of values for this +// as to enable inequality checks. +// +// We also support a positive integer scalar "coeff" that is used for computing +// strides. For example given, a [B, j0, D] tensor, it can be strided in two +// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to +// differentiate the two cases. +// +// During tracing the strides of the outputs need to be a function of the size +// and strides of the inputs so it is important that NestedIntSymNode itself is +// able to express this. +class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl { + public: + // CAUTION: you should probably not be constructing these directly; please + // the higher-level API in python instead (TODO: actually introduce that). + explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff) + : val_(val), coeff_(coeff) {} + + bool bool_() override { + return false; + } + + bool is_int() override { + return true; + } + + bool is_float() override { + return false; + } + + bool is_bool() override { + return false; + } + + bool is_nested_int() const override { + return true; + } + + bool has_hint() override { + return true; + } + + c10::SymNode wrap_int(int64_t num) override { + return SymNode(c10::make_intrusive>(num)); + } + + int64_t guard_int(const char* file, int64_t line) override { + TORCH_CHECK(false); + } + + double guard_float(const char* file, int64_t line) override { + TORCH_CHECK(false, "not a float"); + } + + bool guard_bool(const char* file, int64_t line) override { + TORCH_CHECK(false, "not a bool"); + } + + int64_t int_() override { + TORCH_CHECK(false); + } + + std::string str() override { + if (coeff_ == 1) { + return "j" + std::to_string(val_); + } + return std::to_string(coeff_) + "*j" + std::to_string(val_); + } + + // NOTE [ Inequalities with nested int ] + // + // The semantics of nested int when it comes to relations is that it is + // treated as integer known to be within a certain range, + // + // j0 \in [2, int64_t::max] + // + // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False). + // This is a useful default range for the raggedness pattern of a jagged + // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1 + // specialization checks. + // + // [ Indeterminate inequalities error out ] + // + // Given the semantic defined above, certain relations like j0 < 3 are thus + // indeterminable. In our impl today, evaluating such relations error + // + // It may seem convenient to just define indeterminate relations to return + // False, but the implementation we maintain in parallel using sympy does not + // allow this. + // + // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are, + // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This + // would mean that means that if we define the indeterminate j0 >= 3 to be + // False, the also indeterminate j0 < 3 will be evaluated to be True! + // + // [ Coefficient are assumed positive ] + // + // For the purpose of computing inequalities, we consider the coefficient of + // the nested int to be a positive integer. 
+ // + // Thus, no modifications are needed to the logic since + // j0 >= k implies coeff * j0 >= k + // + c10::SymNode eq(const c10::SymNode& other) override; + c10::SymNode ne(const c10::SymNode& other) override; + c10::SymNode ge(const c10::SymNode& other) override; + c10::SymNode gt(const c10::SymNode& other) override; + c10::SymNode lt(const c10::SymNode& other) override; + c10::SymNode le(const c10::SymNode& other) override; + c10::SymNode mul(const c10::SymNode& other) override; + + std::optional nested_int() override { + return val_; + } + + std::optional nested_int_coeff() override { + return coeff_; + } + + bool is_symbolic() override { + return false; + } + + c10::SymNode clone() override; + +#define DEFINE_BINARY_NOT_SUPPORTED(name) \ + c10::SymNode name(const c10::SymNode& other) override { \ + TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \ + } + + DEFINE_BINARY_NOT_SUPPORTED(add) + DEFINE_BINARY_NOT_SUPPORTED(sub) + DEFINE_BINARY_NOT_SUPPORTED(truediv) + DEFINE_BINARY_NOT_SUPPORTED(pow) + DEFINE_BINARY_NOT_SUPPORTED(floordiv) + DEFINE_BINARY_NOT_SUPPORTED(mod) + DEFINE_BINARY_NOT_SUPPORTED(sym_min) + DEFINE_BINARY_NOT_SUPPORTED(sym_max) + DEFINE_BINARY_NOT_SUPPORTED(sym_and) + DEFINE_BINARY_NOT_SUPPORTED(sym_or) + +#undef DEFINE_BINARY_NOT_SUPPORTED + +#define DEFINE_NOT_SUPPORTED(name) \ + c10::SymNode name() override { \ + TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \ + } + + DEFINE_NOT_SUPPORTED(sym_not) + DEFINE_NOT_SUPPORTED(ceil) + DEFINE_NOT_SUPPORTED(floor) + DEFINE_NOT_SUPPORTED(neg) + DEFINE_NOT_SUPPORTED(sym_float) + +#undef DEFINE_NOT_SUPPORTED + + private: + int64_t val_; + int64_t coeff_; +}; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h b/phivenv/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..6ce1fb0508b5664d8d5cfc669bb21ba948868451 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h @@ -0,0 +1,240 @@ +#pragma once + +// define constants like M_PI and C keywords for MSVC +#ifdef _MSC_VER +#define _USE_MATH_DEFINES +#include +#endif + + +#ifdef __CUDACC__ +#include +#endif + +#include +#include +#include +#include + +namespace at { + +// typedefs for holding vector data +namespace detail { + +typedef std::array UINT4; +typedef std::array UINT2; +typedef std::array DOUBLE2; +typedef std::array FLOAT2; + +} // namespace detail + +/** + * Note [Philox Engine implementation] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Originally implemented in PyTorch's fusion compiler + * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + * for details regarding the engine. + * + * Note that currently this implementation of the philox engine is not used + * anywhere except for tests in cpu_generator_test.cpp. However, this engine + * will replace curandStatePhilox4_32_10_t in the future. + * + * The philox engine takes a seed value, a subsequeunce + * for starting the generation and an offset for the subsequence. + * Think of this engine as an algorithm producing a huge array. We are + * parallelizing this array by partitioning the huge array and assigning + * a thread index to each partition. In other words, each seed value + * (there are 2^64 possible seed values) gives a sub array of size + * 2^128 (each element in that array is a 128 bit number). 
Reasoning + * behind the array being of size 2^128 is, there are 2^64 possible + * thread index value and there is an array of size 2^64 for each of + * those thread index. Hence 2^64 * 2^64 = 2^128 for each seed value. + * + * In short, this generator can produce 2^64 (seed values) * 2^128 (number + * of elements in an array given by a seed value) = 2^192 values. + * + * Arguments: + * seed: Seed values could be any number from 0 to 2^64-1. + * subsequence: Subsequence is just the cuda thread indexing with: + * - blockIdx.x * blockDim.x + threadIdx.x + * offset: The offset variable in PhiloxEngine decides how many 128-bit + * random numbers to skip (i.e. how many groups of 4, 32-bit numbers to skip) + * and hence really decides the total number of randoms that can be achieved + * for the given subsequence. + */ + +class philox_engine { +public: + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721, + uint64_t subsequence = 0, + uint64_t offset = 0) { + + reset_state(seed, subsequence); + incr_n(offset); + } + + C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721, + uint64_t subsequence = 0) { + key_[0] = static_cast(seed); + key_[1] = static_cast(seed >> 32); + counter_ = detail::UINT4{}; + counter_[2] = static_cast(subsequence); + counter_[3] = static_cast(subsequence >> 32); + STATE = 0; + } + + /** + * Set the offset field of Philox Generator to the desired offset. + */ + C10_HOST_DEVICE inline void set_offset(uint64_t offset) { + counter_[0] = static_cast(offset); + counter_[1] = static_cast(offset >> 32); + } + + /** + * Gets the current offset of the Philox Generator. + */ + C10_HOST_DEVICE uint64_t get_offset() const { + uint64_t lo = static_cast(counter_[0]); + uint64_t hi = static_cast(counter_[1]) << 32; + return lo | hi; + } + + /** + * Produces a unique 32-bit pseudo random number on every invocation. Bookeeps state to avoid waste. + */ + C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior + if(STATE == 0) { + detail::UINT4 counter = counter_; + detail::UINT2 key = key_; + output_ = rand(counter, key, n_rounds); + incr(); + } + uint32_t ret = output_[static_cast(STATE)]; + STATE = (STATE + 1) & 3; + return ret; + } + + inline float randn(uint32_t n_rounds) { + #ifdef __CUDA_ARCH__ + AT_ASSERT(false, "Unsupported invocation of randn on CUDA"); + #endif + if(STATE == 0) { + detail::UINT4 counter = counter_; + detail::UINT2 key = key_; + output_ = rand(counter, key, n_rounds); + incr(); + } + // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method + // TODO(voz) We use std:: below, and thus need a separate impl for CUDA. + float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log. 
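+    // The return statement below applies the Box-Muller transform:
+    // sqrt(-2 * ln(u1)) * cos(2 * pi * u2) maps the two uniform samples
+    // to a single sample from the standard normal distribution.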
+ float u2 = 1 - uint32_to_uniform_float(output_[1]); + return static_cast(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2)); + } + + /** + * Function that Skips N 128 bit numbers in a subsequence + */ + C10_HOST_DEVICE inline void incr_n(uint64_t n) { + uint32_t nlo = static_cast(n); + uint32_t nhi = static_cast(n >> 32); + counter_[0] += nlo; + // if overflow in x has occurred, carry over to nhi + if (counter_[0] < nlo) { + nhi++; + // if overflow in nhi has occurred during carry over, + // propagate that overflow to y and exit to increment z + // otherwise return + counter_[1] += nhi; + if(nhi != 0) { + if (nhi <= counter_[1]) { + return; + } + } + } else { + // if overflow in y has occurred during addition, + // exit to increment z + // otherwise return + counter_[1] += nhi; + if (nhi <= counter_[1]) { + return; + } + } + if (++counter_[2]) + return; + ++counter_[3]; + } + + /** + * Function that Skips one 128 bit number in a subsequence + */ + C10_HOST_DEVICE inline void incr() { + if (++counter_[0]) + return; + if (++counter_[1]) + return; + if (++counter_[2]) { + return; + } + ++counter_[3]; + } + +private: + detail::UINT4 counter_; + detail::UINT4 output_; + detail::UINT2 key_; + uint32_t STATE; + + C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b, + uint32_t *result_high) { + #ifdef __CUDA_ARCH__ + *result_high = __umulhi(a, b); + return a*b; + #else + const uint64_t product = static_cast(a) * b; + *result_high = static_cast(product >> 32); + return static_cast(product); + #endif + } + + C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) { + uint32_t hi0 = 0; + uint32_t hi1 = 0; + uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0); + uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1); + detail::UINT4 ret; + ret[0] = hi1 ^ ctr[1] ^ in_key[0]; + ret[1] = lo1; + ret[2] = hi0 ^ ctr[3] ^ in_key[1]; + ret[3] = lo0; + return ret; + } + + C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) { + // maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + constexpr float scale = 4.6566127342e-10; + return static_cast(value & 0x7FFFFFFF) * scale; + } + + + + C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) { + for (uint32_t round = 0; round < (n_rounds - 1); round++) { + counter = single_round(counter, key); + key[0] += (kPhilox10A); key[1] += (kPhilox10B); + } + return single_round(counter, key); + } + + + static const uint32_t kPhilox10A = 0x9E3779B9; + static const uint32_t kPhilox10B = 0xBB67AE85; + static const uint32_t kPhiloxSA = 0xD2511F53; + static const uint32_t kPhiloxSB = 0xCD9E8D57; +}; + +typedef philox_engine Philox4_32; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1639c422829dab38c4333a4236ada449c274045c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h @@ -0,0 +1,35 @@ +#pragma once +#include + + +namespace at::impl { + +struct TORCH_API RestorePythonTLSSnapshot { + RestorePythonTLSSnapshot(); + RestorePythonTLSSnapshot(RestorePythonTLSSnapshot&& other) = delete; + RestorePythonTLSSnapshot(const RestorePythonTLSSnapshot&) = delete; + RestorePythonTLSSnapshot& operator=(const RestorePythonTLSSnapshot&) = delete; + RestorePythonTLSSnapshot& operator=(RestorePythonTLSSnapshot&&) = 
delete; + ~RestorePythonTLSSnapshot(); + +private: + c10::impl::LocalDispatchKeySet saved_; + c10::impl::ForceDispatchKeyGuard guard_; +}; + + +// RAII guard to make working with the above TLS safer. +struct TORCH_API MaybeSetTLSOnEntryGuard { +public: + MaybeSetTLSOnEntryGuard(); + MaybeSetTLSOnEntryGuard(MaybeSetTLSOnEntryGuard&& other) = delete; + MaybeSetTLSOnEntryGuard(const MaybeSetTLSOnEntryGuard&) = delete; + MaybeSetTLSOnEntryGuard& operator=(const MaybeSetTLSOnEntryGuard&) = delete; + MaybeSetTLSOnEntryGuard& operator=(MaybeSetTLSOnEntryGuard&&) = delete; + ~MaybeSetTLSOnEntryGuard(); + +private: + bool value_set_; +}; + +} // namespace at::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h b/phivenv/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h new file mode 100644 index 0000000000000000000000000000000000000000..2a18d9d23a509763ca4befda69881c37c4513cc0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +// TODO: this can probably live in c10 + + +namespace at::impl { + +class TORCH_API PythonOpRegistrationTrampoline final { + static std::atomic interpreter_; + +public: + // Returns true if you successfully registered yourself (that means + // you are in the hot seat for doing the operator registrations!) + static bool registerInterpreter(c10::impl::PyInterpreter*); + + // Returns nullptr if no interpreter has been registered yet. + static c10::impl::PyInterpreter* getInterpreter(); +}; + +} // namespace at::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h b/phivenv/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h new file mode 100644 index 0000000000000000000000000000000000000000..165baf43421be5c04d9e4869f9f8d2b4de797e3d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include + +namespace at { + +class Tensor; +struct QTensorImpl; +struct Quantizer; +using ConstQuantizerPtr = const c10::intrusive_ptr&; +using QuantizerPtr = c10::intrusive_ptr; + +/** + * Quantizer is the class for storing all the information + * that's necessary to perform quantize and dequantize + * operation. + * + * We might have different types of quantization schemes and this is + * the base class for all quantizers. + * + * QTensorImpl will hold a pointer to Quantizer so that we can support + * different quantization schemes on Tensor. + * + * For example, the most common quantization scheme, Affine Quantization, + * requires scale and zero_point as parameters, we'll store scale and zero_point + * inside the instance and we can use it to quantize a float Tensor or + * dequantize a quantized Tensor. + * + * When you add new types of leaf Quantizer class, please also + * make sure to add a corresponding QScheme enum since + * they should have one to one mapping. + * + * Note about intrusive_ptr: + * Quantized Tensor holds an intrusive_ptr to Quantizer, and multiple Tensor can + * share the same Quantizer. Quantizer should be immutable. 
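+ *
+ * For example (informally), an affine quantizer with parameters
+ * (scale, zero_point) quantizes a float value x to q = round(x / scale) + zero_point
+ * and dequantizes it back as x' = (q - zero_point) * scale.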
+ */ +struct TORCH_API Quantizer : public c10::intrusive_ptr_target { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const ScalarType scalar_type_; + explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {} + ~Quantizer() override = default; + + // Copied from torch/csrc/jit/ir/scope.h + QuantizerPtr intrusive_from_this() { + c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer + // from a raw `this` pointer + // so we need to bump the refcount + // to account for this ownership + return c10::intrusive_ptr::reclaim(this); + } + + /** + * Each concrete Quantizer type should have a unique QScheme type. + */ + virtual QScheme qscheme() const = 0; + + ScalarType scalar_type() const { + return scalar_type_; + } + + /** + * quantize a float Tensor into a quantized Tensor. + */ + virtual Tensor quantize(const Tensor& t) = 0; + + /** + * dequantize a quantized Tensor into a float Tensor. + */ + virtual Tensor dequantize(const Tensor& t) = 0; + + /** + * dequantize a quantized Tensor into a float Tensor, out= variant + */ + virtual Tensor& dequantize_out(Tensor& out, const Tensor& t) = 0; + + /** + * Compare against `other` for equality. + */ + virtual bool equalTo(QuantizerPtr other) const = 0; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Range.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Range.h new file mode 100644 index 0000000000000000000000000000000000000000..eb79331a2fa8e6520929badeeab10868d9f6f23e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Range.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace at { + +struct Range { + Range(int64_t begin, int64_t end) + : begin(begin) + , end(end) {} + + int64_t size() const { return end - begin; } + + Range operator/(int64_t divisor) { + return Range(begin / divisor, end / divisor); + } + + int64_t begin; + int64_t end; +}; + +std::ostream& operator<<(std::ostream& out, const Range& range); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Reduction.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..a540db6fb9334f964273236f9096467ac472ae49 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Reduction.h @@ -0,0 +1,14 @@ +#pragma once + +namespace at::Reduction { + +// NB: Keep this in sync with Reduction class in torch/nn/_reduction.py +// These constants control the reduction behavior of loss functions. 
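+// For example, a loss computed with Reduction::None returns the unreduced
+// per-element losses, Reduction::Mean returns their (possibly weighted) mean,
+// and Reduction::Sum returns their sum.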
+// Ideally, this would be a scoped enum, but jit doesn't support that +enum Reduction { + None, // Do not reduce + Mean, // (Possibly weighted) mean of losses + Sum, // Sum losses + END +}; +} // namespace at::Reduction diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Scalar.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Scalar.h new file mode 100644 index 0000000000000000000000000000000000000000..7c1649491a2b69236e633783b794d97faf3546b1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Scalar.h @@ -0,0 +1 @@ +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/ScalarType.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ScalarType.h new file mode 100644 index 0000000000000000000000000000000000000000..b83740b82dc25709e2aa8d2252c7fad88b4638dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ScalarType.h @@ -0,0 +1 @@ +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Tensor.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..96fe83b9954dafcefc900ac6c95f6c3b715ff99f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Tensor.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +namespace at { +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class TORCH_API OptionalTensorRef { + public: + OptionalTensorRef() = default; + + ~OptionalTensorRef() { + ref_.unsafeReleaseTensorImpl(); + } + + OptionalTensorRef(const TensorBase& src) + : ref_(Tensor::unsafe_borrow_t{}, src) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined()); + } + + OptionalTensorRef(const OptionalTensorRef& rhs) + : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {} + + OptionalTensorRef(OptionalTensorRef&& rhs) = default; + OptionalTensorRef& operator=(OptionalTensorRef rhs) { + std::swap(ref_, rhs.ref_); + return *this; + } + + bool has_value() const { + return ref_.defined(); + } + + const Tensor& getTensorRef() const & { + return ref_; + } + + const Tensor& operator*() const & { + return ref_; + } + + const Tensor* operator->() const & { + return &ref_; + } + + operator bool() const { + return ref_.defined(); + } + + private: + Tensor ref_; +}; + +// Use to convert a TensorBase (that may be undefined) to an at::Tensor +// without bumping refcount. 
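+//
+// A hedged usage sketch (illustrative):
+//
+//   void use(const at::TensorBase& base) {
+//     at::TensorRef ref(base);       // borrow: no refcount bump
+//     const at::Tensor& t = *ref;    // view the same TensorImpl as a Tensor
+//     // ... use t ...
+//   }                                // ref releases the borrow on destruction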
+class TORCH_API TensorRef { + public: + ~TensorRef() { + ref_.unsafeReleaseTensorImpl(); + } + + TensorRef(const TensorBase& src) + : ref_(Tensor::unsafe_borrow_t{}, src) {} + TensorRef(TensorRef&& other) = default; + TensorRef(const TensorRef&) = default; + TensorRef& operator=(const TensorRef&) = default; + TensorRef& operator=(TensorRef&&) = default; + + const Tensor& operator*() const & { + return ref_; + } + private: + Tensor ref_; +}; + +template +auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t { + // Return the grad argument in case of a hook with void return type to have an + // std::function with Tensor return type + static_assert(std::is_same_v, + "Expected hook to return void"); + return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { + TensorRef grad(grad_base); + fn(*grad); + return Tensor(); + }); +} + +template +auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t { + return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { + TensorRef grad(grad_base); + Tensor ret = fn(*grad); + return TensorBase(std::move(ret)); + }); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h b/phivenv/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..9de6b1c88e333fe4187de1ae3a277449632263c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h @@ -0,0 +1,275 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor +// is used to enable the __restrict__ keyword/modifier for the data +// passed to cuda. +template +struct DefaultPtrTraits { + typedef T* PtrType; +}; + +#if defined(__CUDACC__) || defined(__HIPCC__) +template +struct RestrictPtrTraits { + typedef T* __restrict__ PtrType; +}; +#endif + +// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors. +// For CUDA tensors it is used in device code (only). This means that we restrict ourselves +// to functions and types available there (e.g. IntArrayRef isn't). + +// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers. +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class TensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST_DEVICE TensorAccessorBase( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : data_(data_), sizes_(sizes_), strides_(strides_) {} + C10_HOST IntArrayRef sizes() const { + return IntArrayRef(sizes_,N); + } + C10_HOST IntArrayRef strides() const { + return IntArrayRef(strides_,N); + } + C10_HOST_DEVICE index_t stride(index_t i) const { + return strides_[i]; + } + C10_HOST_DEVICE index_t size(index_t i) const { + return sizes_[i]; + } + C10_HOST_DEVICE PtrType data() { + return data_; + } + C10_HOST_DEVICE const PtrType data() const { + return data_; + } +protected: + PtrType data_; + const index_t* sizes_; + const index_t* strides_; +}; + +// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using +// `Tensor.accessor()`. +// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only +// indexing on the device uses `TensorAccessor`s. 
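+//
+// A hedged usage sketch for a CPU tensor (illustrative):
+//
+//   // given a 2-D float tensor `t`
+//   auto a = t.accessor<float, 2>();
+//   for (int64_t i = 0; i < a.size(0); i++) {
+//     for (int64_t j = 0; j < a.size(1); j++) {
+//       a[i][j] *= 2.0f;
+//     }
+//   }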
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class TensorAccessor : public TensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST_DEVICE TensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : TensorAccessorBase(data_,sizes_,strides_) {} + + C10_HOST_DEVICE TensorAccessor operator[](index_t i) { + return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); + } + + C10_HOST_DEVICE const TensorAccessor operator[](index_t i) const { + return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); + } +}; + +template class PtrTraits, typename index_t> +class TensorAccessor : public TensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST_DEVICE TensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : TensorAccessorBase(data_,sizes_,strides_) {} + C10_HOST_DEVICE T & operator[](index_t i) { + // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) + return this->data_[this->strides_[0]*i]; + } + C10_HOST_DEVICE const T & operator[](index_t i) const { + return this->data_[this->strides_[0]*i]; + } +}; + + +// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host +// and as +// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host) +// in order to transfer them on the device when calling kernels. +// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s. +// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__. +// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available +// on the device, so those functions are host only. 
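+//
+// A hedged usage sketch for a CUDA kernel (illustrative; `my_kernel` is a
+// hypothetical kernel, not part of this header):
+//
+//   __global__ void my_kernel(
+//       at::PackedTensorAccessor32<float, 2, at::RestrictPtrTraits> acc) {
+//     acc[blockIdx.x][threadIdx.x] = 0.f;  // device-side indexing yields TensorAccessors
+//   }
+//
+//   // host side: sizes and strides are copied into the accessor by value
+//   auto acc = t.packed_accessor32<float, 2, at::RestrictPtrTraits>();
+//   my_kernel<<<blocks, threads>>>(acc);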
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class GenericPackedTensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + C10_HOST GenericPackedTensorAccessorBase( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : data_(data_) { + std::copy(sizes_, sizes_ + N, std::begin(this->sizes_)); + std::copy(strides_, strides_ + N, std::begin(this->strides_)); + } + + // if index_t is not int64_t, we want to have an int64_t constructor + template >> + C10_HOST GenericPackedTensorAccessorBase( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : data_(data_) { + for (const auto i : c10::irange(N)) { + this->sizes_[i] = sizes_[i]; + this->strides_[i] = strides_[i]; + } + } + + C10_HOST_DEVICE index_t stride(index_t i) const { + return strides_[i]; + } + C10_HOST_DEVICE index_t size(index_t i) const { + return sizes_[i]; + } + C10_HOST_DEVICE PtrType data() { + return data_; + } + C10_HOST_DEVICE const PtrType data() const { + return data_; + } +protected: + PtrType data_; + // NOLINTNEXTLINE(*c-arrays*) + index_t sizes_[N]; + // NOLINTNEXTLINE(*c-arrays*) + index_t strides_[N]; + C10_HOST void bounds_check_(index_t i) const { + TORCH_CHECK_INDEX( + 0 <= i && i < index_t{N}, + "Index ", + i, + " is not within bounds of a tensor of dimension ", + N); + } +}; + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + // if index_t is not int64_t, we want to have an int64_t constructor + template >> + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + C10_DEVICE TensorAccessor operator[](index_t i) { + index_t* new_sizes = this->sizes_ + 1; + index_t* new_strides = this->strides_ + 1; + return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); + } + + C10_DEVICE const TensorAccessor operator[](index_t i) const { + const index_t* new_sizes = this->sizes_ + 1; + const index_t* new_strides = this->strides_ + 1; + return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); + } + + /// Returns a PackedTensorAccessor of the same dimension after transposing the + /// two dimensions given. Does not actually move elements; transposition is + /// made by permuting the size/stride arrays. If the dimensions are not valid, + /// asserts. 
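+  /// For example, after `auto tr = acc.transpose(0, 1);` on a 2-D accessor,
+  /// `tr[j][i]` addresses the same element as `acc[i][j]`.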
+ C10_HOST GenericPackedTensorAccessor transpose( + index_t dim1, + index_t dim2) const { + this->bounds_check_(dim1); + this->bounds_check_(dim2); + GenericPackedTensorAccessor result( + this->data_, this->sizes_, this->strides_); + std::swap(result.strides_[dim1], result.strides_[dim2]); + std::swap(result.sizes_[dim1], result.sizes_[dim2]); + return result; + } +}; + +template class PtrTraits, typename index_t> +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + // if index_t is not int64_t, we want to have an int64_t constructor + template >> + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + C10_DEVICE T & operator[](index_t i) { + return this->data_[this->strides_[0] * i]; + } + C10_DEVICE const T& operator[](index_t i) const { + return this->data_[this->strides_[0]*i]; + } + + // Same as in the general N-dimensional case, but note that in the + // 1-dimensional case the returned PackedTensorAccessor will always be an + // identical copy of the original + C10_HOST GenericPackedTensorAccessor transpose( + index_t dim1, + index_t dim2) const { + this->bounds_check_(dim1); + this->bounds_check_(dim2); + return GenericPackedTensorAccessor( + this->data_, this->sizes_, this->strides_); + } +}; + + +// Can't put this directly into the macro function args because of commas +#define AT_X GenericPackedTensorAccessor + +// Old name for `GenericPackedTensorAccessor` +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X) + +#undef AT_X + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor32 = GenericPackedTensorAccessor; + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor64 = GenericPackedTensorAccessor; +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/TensorBase.h b/phivenv/Lib/site-packages/torch/include/ATen/core/TensorBase.h new file mode 100644 index 0000000000000000000000000000000000000000..0402aada2f0719bbdc847bd52e32e0e24009ca6c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/TensorBase.h @@ -0,0 +1,1056 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { +class Scalar; +} + +namespace torch::autograd { + +struct Node; + +} // namespace torch::autograd + +namespace at { + +class Tensor; +class TensorBase; + +// Convert Tensor to TensorBase without any need to include Tensor.h +TORCH_API const TensorBase& get_tensor_base(const Tensor& t); + +namespace impl { +inline bool variable_excluded_from_dispatch() { +#ifdef C10_MOBILE + // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change. + return true; +#else + return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset); +#endif +} + +} + +// NOTE: [Tensor vs. TensorBase] +// +// Tensor, being the central data structure in PyTorch, gets used and +// its header included almost everywhere. 
Unfortunately this means +// every time an operator signature is updated or changed in +// native_functions.yaml, you (and every other PyTorch developer) need +// to recompile all of ATen and its dependencies. +// +// TensorBase aims to break up these header dependencies, and improve +// incremental build times for all PyTorch developers. TensorBase +// represents a reference counted handle to TensorImpl, exactly the +// same as Tensor. However, TensorBase doesn't have code generated +// methods in its API and thus no dependence on native_functions.yaml. +// +// Usage tips +// ---------- +// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp +// or .cu file to ensure it has no header dependencies on +// native_functions.yaml (direct or indirect). +// - Tensor inherits from TensorBase, so functions taking +// `const TensorBase &` are callable with Tensor as well. +// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`, +// but this requires a reference-count bump. OptionalTensorRef, on +// the other hand, can materialize a `const Tensor &` without +// touching the reference-count. +class TORCH_API TensorBase { + public: + struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; }; + + protected: + // Create a Tensor with a +0 reference count. Special care must be + // taken to avoid decrementing this reference count at destruction + // time. Intended to support MaybeOwnedTraits. + explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs) + : impl_(c10::intrusive_ptr(rhs.impl_.get(), c10::raw::DontIncreaseRefcount{})) {} + friend MaybeOwnedTraits; + + public: + TensorBase() = default; + // This constructor should not be used by end users and is an implementation + // detail invoked by autogenerated code. + explicit TensorBase( + c10::intrusive_ptr tensor_impl) + : impl_(std::move(tensor_impl)) { + if (impl_.get() == nullptr) { + throw std::runtime_error("TensorImpl with nullptr is not supported"); + } + } + TensorBase(const TensorBase&) = default; + TensorBase(TensorBase&&) noexcept = default; + ~TensorBase() noexcept = default; + + public: + // Creates a new wrapper from TensorImpl. Intentionally a free method because + // it should be used with care. Checks necessary invariants + static TensorBase wrap_tensor_impl( + c10::intrusive_ptr tensor_impl) { + TensorBase r(std::move(tensor_impl)); + r.enforce_invariants(); + return r; + } + + int64_t dim() const { + return impl_->dim(); + } + int64_t storage_offset() const { + return impl_->storage_offset(); + } + + TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + if (is_contiguous(memory_format)) { + return *this; + } else { + return __dispatch_contiguous(memory_format); + } + } + + /// Should be used if *this can reasonably be expected to be contiguous and + /// performance is important. + /// Compared to contiguous, it saves a reference count + /// increment/decrement if *this is already contiguous, at the cost + /// in all cases of an extra pointer of stack usage, an extra branch + /// to access, and an extra branch at destruction time. + c10::MaybeOwned expect_contiguous( + MemoryFormat memory_format=MemoryFormat::Contiguous) const &; + + // Use .contiguous() instead. Trying to borrow from a prvalue + // will only lead to trouble and dangling references. 
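+  // A hedged usage sketch for the lvalue overload above (illustrative):
+  //
+  //   c10::MaybeOwned<TensorBase> c = self.expect_contiguous();
+  //   const TensorBase& t = *c;  // borrows `self` when already contiguous,
+  //                              // otherwise owns a freshly materialized copy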
+ c10::MaybeOwned expect_contiguous( + MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; + + const TensorBase& fill_(const c10::Scalar& scalar) const; + const TensorBase& zero_() const; + + TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional memory_format=std::nullopt) const; + + bool is_complex() const { + return at::isComplexType(this->scalar_type()); + } + + bool is_floating_point() const { + return at::isFloatingType(this->scalar_type()); + } + + bool is_signed() const { + return at::isSignedType(this->scalar_type()); + } + + c10::SymInt sym_size(int64_t dim) const { + return impl_->sym_size(dim); + } + + c10::SymInt sym_stride(int64_t dim) const { + const auto sizes = this->sym_strides(); + const auto ndim = static_cast(sizes.size()); + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)]; + + } + + int64_t size(int64_t dim) const { + return impl_->size(dim); + } + + int64_t stride(int64_t dim) const { + const auto strides = this->strides(); + const auto ndim = static_cast(strides.size()); + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)]; + } + + TensorImpl * unsafeGetTensorImpl() const { + return impl_.get(); + } + TensorImpl * unsafeReleaseTensorImpl() { + return impl_.release(); + } + const c10::intrusive_ptr& getIntrusivePtr() const { + return impl_; + } + + c10::intrusive_ptr unsafeReleaseIntrusivePtr() { + return std::move(impl_); + } + + bool defined() const { + return impl_; + } + + void reset() { + impl_.reset(); + } + +#if defined (_MSC_VER) + TensorBase& operator=(const TensorBase& x) & { + impl_ = x.impl_; + return *this; + }; + TensorBase& operator=(TensorBase&& x) & noexcept { + impl_ = std::move(x.impl_); + return *this; + } +#else + TensorBase& operator=(const TensorBase& x) & = default; + TensorBase& operator=(TensorBase&& x) & noexcept = default; +#endif + + // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here + TensorBase& operator=(const TensorBase&) && = delete; + TensorBase& operator=(TensorBase&&) && noexcept = delete; + + bool is_same(const TensorBase& other) const noexcept { + return impl_ == other.impl_; + } + size_t use_count() const noexcept { + return impl_.use_count(); + } + size_t weak_use_count() const noexcept { + return impl_.weak_use_count(); + } + + std::string toString() const; + + IntArrayRef sizes() const { + return impl_->sizes(); + } + c10::SymIntArrayRef sym_sizes() const { + return impl_->sym_sizes(); + } + c10::SymIntArrayRef sym_strides() const { + return impl_->sym_strides(); + } + IntArrayRef strides() const { + return impl_->strides(); + } + // See impl::get_opt_names in ATen/NamedTensor.h for docs. + std::optional opt_names() const { + return impl::get_opt_names(unsafeGetTensorImpl()); + } + // See impl::get_names in ATen/NamedTensor.h for docs. 
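// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header.
// A minimal use of the metadata getters declared above (defined(), dim(),
// size(), stride()). The helper name `describe` is hypothetical.
//
//   #include <iostream>
//   #include <ATen/core/TensorBase.h>
//
//   void describe(const at::TensorBase& t) {
//     if (!t.defined()) {  // default-constructed handles hold no TensorImpl
//       std::cout << "<undefined tensor>\n";
//       return;
//     }
//     for (int64_t d = 0; d < t.dim(); ++d) {
//       // size()/stride() wrap negative dims like ordinary array indexing
//       std::cout << "dim " << d << ": size=" << t.size(d)
//                 << ", stride=" << t.stride(d) << "\n";
//     }
//   }
// ---------------------------------------------------------------------------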
+ DimnameList names() const { + return impl::get_names(unsafeGetTensorImpl()); + } + int64_t ndimension() const { + return dim(); + } + + bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { + return impl_->is_contiguous(memory_format); + } + + bool is_non_overlapping_and_dense() const { + return impl_->is_non_overlapping_and_dense(); + } + + at::MemoryFormat suggest_memory_format( + bool channels_last_strides_exact_match = false) const { + // Setting channels_last_strides_exact_match to true forces function to + // check 0,1 - sized dimension strides. + if (layout() == at::kStrided) { + if (impl_->is_strides_like_channels_last()) { + if (!channels_last_strides_exact_match || + get_channels_last_strides_2d(sizes()) == strides()) { + return at::MemoryFormat::ChannelsLast; + } + } + else if (impl_->is_strides_like_channels_last_3d()) { + if (!channels_last_strides_exact_match || + get_channels_last_strides_3d(sizes()) == strides()) { + return at::MemoryFormat::ChannelsLast3d; + } + } + } + return at::MemoryFormat::Contiguous; + } + + // Total bytes consumed by the "view" of elements of the array. Does not + // include size of metadata. The number reported here does not necessarily + // correspond to the true physical memory consumed by a tensor; instead, + // it reports the memory the tensor would take *if* it were contiguous. + // Defined to be numel() * itemsize() + size_t nbytes() const { + TORCH_CHECK(layout () != at::kSparse, + "nbytes is not defined for sparse tensors. If you want the size of the constituent " \ + "tensors, add the nbytes of the indices and values. If you want the size of the " \ + "equivalent dense tensor, multiply numel() by element_size()"); + return impl_->numel() * impl_->itemsize(); + } + + c10::SymInt sym_nbytes() const { + TORCH_CHECK(layout () != at::kSparse, + "nbytes is not defined for sparse tensors. If you want the size of the constituent " \ + "tensors, add the nbytes of the indices and values. If you want the size of the " \ + "equivalent dense tensor, multiply numel() by element_size()"); + return impl_->sym_numel() * impl_->itemsize(); + } + + int64_t numel() const { + return impl_->numel(); + } + + c10::SymInt sym_numel() const { + return impl_->sym_numel(); + } + + c10::SymInt sym_storage_offset() const { + return impl_->sym_storage_offset(); + } + + // Length of one array element in bytes. This is the traditional + // Numpy naming. + size_t itemsize() const { + return impl_->itemsize(); + } + + // Same as itemsize(). This is the PyTorch naming. + int64_t element_size() const { + return static_cast(impl_->itemsize()); + } + + DispatchKeySet key_set() const { + return impl_->key_set(); + } + ScalarType scalar_type() const { + return typeMetaToScalarType(impl_->dtype()); + } + bool has_storage() const { + return defined() && impl_->has_storage(); + } + const Storage& storage() const { + return impl_->storage(); + } + bool is_alias_of(const at::TensorBase& other) const{ + return impl_->storage().is_alias_of(other.storage()); + } + + // Move the storage backend to shm based + // to enable memory sharing across processes. + // + // NB1: the ideal behavior of this API still requires further discussion + // but for now we are inclined to keep it consistent with existing THP behavior + // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212 + // so we don't assert on anything here and rely on caller knowing + // what it's doing. 
+ // + // NB2: this currently provides Linux fd based shm support only + // to simplify the storage lifetime management logic in ATen + // and similarly for now we are not adding support for file system based + // shm support like in THP due to additional GC manager support needed + // to prevent leaks. + // As such, calling this from non supported systems (e.g. Windows) would fail. + void share_memory_() { + at::share_memory_(*this); + } + + inline bool _is_zerotensor() const { + return impl_->_is_zerotensor(); + } + + inline void _set_zero(bool zero) const { + impl_->_set_zero(zero); + } + + inline bool is_conj() const { + return impl_->is_conj(); + } + + // sets the conjugate bit of a tensor. + // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure + // that's what you want. Changing this might lead to incorrect behavior since conjugation is + // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized. + inline void _set_conj(bool conjugate) const { + impl_->_set_conj(conjugate); + } + + inline bool is_neg() const { + return impl_->is_neg(); + } + + // sets the negative bit of a tensor. + // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure + // that's what you want. Changing this might lead to incorrect behavior since we rely on this + // bit to determine if a negation needs to be materialized. + inline void _set_neg(bool negative) const { + impl_->_set_neg(negative); + } + + /// Returns a `Tensor`'s layout. + Layout layout() const { + return impl_->layout(); + } + + /// Returns a `Tensor`'s dtype (`TypeMeta`). + caffe2::TypeMeta dtype() const { + return impl_->dtype(); + } + + /// Returns a `Tensor`'s device. + inline Device device() const { + return impl_->device(); + } + + /// Returns a `Tensor`'s device index. + DeviceIndex get_device() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->get_device(); + } + + /// Returns if a `Tensor` has CPU backend. + bool is_cpu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_cpu(); + } + + /// Returns if a `Tensor` has CUDA backend. + bool is_cuda() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_cuda(); + } + + /// Returns if a `Tensor` has IPU backend. + bool is_ipu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_ipu(); + } + + /// Returns if a `Tensor` has XPU backend. + bool is_xpu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_xpu(); + } + + /// Returns if a `Tensor` has XLA backend. + bool is_xla() const { + return impl_->is_xla(); + } + + /// Returns if a `Tensor` has MTIA backend. + bool is_mtia() const { + return impl_->is_mtia(); + } + + /// Returns if a `Tensor` has HPU backend. + bool is_hpu() const { + return impl_->is_hpu(); + } + + /// Returns if a `Tensor` has Lazy backend. + bool is_lazy() const { + return impl_->is_lazy(); + } + + /// Returns if a `Tensor` has HIP backend. + bool is_hip() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_hip(); + } + + /// Returns if a `Tensor` has VE backend. + bool is_ve() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_ve(); + } + + /// Returns if a `Tensor` has PrivateUse1 backend. 
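// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header.
// The device/dtype/layout getters and backend predicates above read TensorImpl
// state directly (no dispatch). The helper name `log_placement` is
// hypothetical.
//
//   #include <iostream>
//   #include <ATen/core/TensorBase.h>
//
//   void log_placement(const at::TensorBase& t) {
//     std::cout << "device=" << t.device()
//               << " dtype=" << t.dtype()
//               << " layout=" << t.layout() << "\n";
//     if (t.is_cuda()) {
//       std::cout << "  CUDA device index: "
//                 << static_cast<int>(t.get_device()) << "\n";
//     }
//   }
// ---------------------------------------------------------------------------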
+ bool is_privateuseone() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_privateuseone(); + } + + /// Returns if a `Tensor` has sparse backend. + bool is_sparse() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_sparse(); + } + + /// Returns is a `Tensor` has a sparse CSR backend. + bool is_sparse_csr() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_sparse_csr(); + } + + /// Returns if a `Tensor` is mkldnn tensor. + bool is_mkldnn() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_mkldnn(); + } + + /// Returns if a `Tensor` is mps tensor. + bool is_mps() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_mps(); + } + + /// Returns if a `Tensor` is maia tensor. + bool is_maia() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_maia(); + } + + /// Returns if a `Tensor` is vulkan tensor. + bool is_vulkan() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_vulkan(); + } + + /// Returns if a `Tensor` is metal tensor. + bool is_metal() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_metal(); + } + + /// Returns if a `Tensor` has quantized backend. + bool is_quantized() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_quantized(); + } + + /// Returns if a `Tensor` is a meta tensor. Meta tensors can + /// also have other designations. + bool is_meta() const { + return impl_->is_meta(); + } + + /// Returns if a `Tensor` is an inference tensor. + bool is_inference() const { + return impl_->is_inference(); + } + + // Returns if a `Tensor` is a NestedTensor. + bool is_nested() const { + return impl_->is_nested(); + } + + /// If a tensor is a quantized tensor, returns its quantizer + /// TODO: it's not in native_functions.yaml yet as it's not exposed to python + QuantizerPtr quantizer() const; + + /// Returns if a `Tensor` has any dimension names + bool has_names() const { + // If a user is using unnamed tensors, then we can short-circuit right here. + // Otherwise, impl::has_names attempts to retrieve names. + if (!impl_->has_named_tensor_meta()) { + return false; + } + return impl::has_names(unsafeGetTensorImpl()); + } + + /// Returns a `Tensor`'s dimension names data structure + const NamedTensorMeta* get_named_tensor_meta() const { + return static_cast(impl_->named_tensor_meta()); + } + + NamedTensorMeta* get_named_tensor_meta() { + return static_cast(impl_->named_tensor_meta()); + } + + /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in + /// TensorOptions.h. + TensorOptions options() const { + return TensorOptions().dtype(dtype()) + .device(device()) + .layout(layout()); + } + + const void* const_data_ptr() const { + return this->unsafeGetTensorImpl()->data(); + } + + void* mutable_data_ptr() const { + return this->unsafeGetTensorImpl()->mutable_data(); + } + + // TODO(#97856) Make this return a const pointer. This currently + // returns a non-const pointer because of the large + // number of clients that we still want to audit before + // migrating to mutable_data_ptr(). 
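// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header.
// It shows options() feeding a factory function and the read-only
// const_data_ptr() accessor declared above. `zeros_like_via_options` and
// `first_element` are hypothetical names; at::zeros comes from the generated
// ATen factory headers (e.g. ATen/ATen.h), not from TensorBase.h.
//
//   #include <ATen/ATen.h>
//
//   at::Tensor zeros_like_via_options(const at::Tensor& src) {
//     // options() bundles dtype/device/layout so factories can replicate them.
//     return at::zeros(src.sizes(), src.options());
//   }
//
//   const void* first_element(const at::TensorBase& t) {
//     // const_data_ptr() is the read-only counterpart of mutable_data_ptr().
//     return t.defined() ? t.const_data_ptr() : nullptr;
//   }
// ---------------------------------------------------------------------------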
+ void* data_ptr() const { + return mutable_data_ptr(); + } + + template , int> = 0> + const T* const_data_ptr() const; + + template , int> = 0> + const std::remove_const_t* const_data_ptr() const; + + template + T* mutable_data_ptr() const; + + // Legacy interface during the migration to indicate that a callsite + // has not been audited for mutability. + // + // Do not add new uses of this, use const_data_ptr() if possible, + // mutable_data_ptr() otherwise. + // + // TODO(#97856) Make this return a const pointer. This is currently + // const because of the vast number of clients that + // rely on this. + template + T* data_ptr() const; + + // Purposely not defined here to avoid inlining + void print() const; + + // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and + // dimension. + template + TensorAccessor accessor() const& { + static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); + T* ptr = nullptr; + if constexpr (std::is_const_v) { + ptr = const_data_ptr(); + } else { + ptr = mutable_data_ptr(); + } + return TensorAccessor(ptr,sizes().data(),strides().data()); + } + template + TensorAccessor accessor() && = delete; + + // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and + // dimension. You can optionally specify RestrictPtrTraits as a template parameter to + // cast the data pointer to a __restrict__ pointer. + // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor + // as an argument. + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() const& { + static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); + T* ptr = nullptr; + if constexpr (std::is_const_v) { + ptr = const_data_ptr(); + } else { + ptr = mutable_data_ptr(); + } + return GenericPackedTensorAccessor(static_cast::PtrType>(ptr),sizes().data(),strides().data()); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() const& { + TORCH_CHECK( + impl_->numel() <= + static_cast(std::numeric_limits::max()), + "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64"); + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() && = delete; + + // ~~~~~ Autograd API ~~~~~ + + /// \fn bool is_leaf() const; + /// + /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention. + /// + /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were + /// created by the user. This means that they are not the result of an operation and so + /// `grad_fn()` is `nullptr`. + /// + /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`. 
+ /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`. + /// + /// Example: + /// @code + /// auto a = torch::rand(10, torch::requires_grad()); + /// std::cout << a.is_leaf() << std::endl; // prints `true` + /// + /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA); + /// std::cout << b.is_leaf() << std::endl; // prints `false` + /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor + /// + /// auto c = torch::rand(10, torch::requires_grad()) + 2; + /// std::cout << c.is_leaf() << std::endl; // prints `false` + /// // c was created by the addition operation + /// + /// auto d = torch::rand(10).cuda(); + /// std::cout << d.is_leaf() << std::endl; // prints `true` + /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + /// + /// auto e = torch::rand(10).cuda().requires_grad_(); + /// std::cout << e.is_leaf() << std::endl; // prints `true` + /// // e requires gradients and has no operations creating it + /// + /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true)); + /// std::cout << f.is_leaf() << std::endl; // prints `true` + /// // f requires grad, has no operation creating it + /// @endcode + + /// \fn void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const; + /// + /// Computes the gradient of current tensor with respect to graph leaves. + /// + /// The graph is differentiated using the chain rule. If the tensor is + /// non-scalar (i.e. its data has more than one element) and requires + /// gradient, the function additionally requires specifying ``gradient``. + /// It should be a tensor of matching type and location, that contains + /// the gradient of the differentiated function w.r.t. this Tensor. + /// + /// This function accumulates gradients in the leaves - you might need to + /// zero them before calling it. + /// + /// \param gradient Gradient w.r.t. the + /// tensor. If it is a tensor, it will be automatically converted + /// to a Tensor that does not require grad unless ``create_graph`` is True. + /// None values can be specified for scalar Tensors or ones that + /// don't require grad. If a None value would be acceptable then + /// this argument is optional. + /// \param retain_graph If ``false``, the graph used to compute + /// the grads will be freed. Note that in nearly all cases setting + /// this option to True is not needed and often can be worked around + /// in a much more efficient way. Defaults to the value of + /// ``create_graph``. + /// \param create_graph If ``true``, graph of the derivative will + /// be constructed, allowing to compute higher order derivative + /// products. Defaults to ``false``. + /// \param inputs Inputs w.r.t. which the gradient will be accumulated into + /// ``at::Tensor::grad``. All other Tensors will be ignored. If not + /// provided, the gradient is accumulated into all the leaf Tensors + /// that were used to compute the current tensor. + /// When inputs are provided and a given input is not a leaf, + /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). + /// It is an implementation detail on which the user should not rely. + /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + + /// \fn Tensor detach() const; + /// + /// Returns a new Tensor, detached from the current graph. 
+ /// The result will never require gradient. + + /// \fn Tensor & detach_() const; + /// + /// Detaches the Tensor from the graph that created it, making it a leaf. + /// Views cannot be detached in-place. + + /// \fn void retain_grad() const; + /// + /// Enables this Tensor to have their :attr:`grad` populated during + /// :func:`backward`. This is a no-op for leaf tensors. + + /// \fn bool retains_grad() const; + /// + /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be + /// populated during :func:`backward`, ``false`` otherwise. + + const TensorBase& set_requires_grad(bool requires_grad) const { + impl_->set_requires_grad(requires_grad); + return *this; + } + bool requires_grad() const { + return impl_->requires_grad(); + } + + // The Forward AD API functions below are low level and are not to be used by end + // users who should use the API provided in torch/csrc/autograd.h + + /// This function returns the forward gradient for this Tensor at the given level. + const Tensor& _fw_grad(uint64_t level) const { + return impl_->_fw_grad(level, *this); + } + + /// This function can be used to set the value of the forward grad. + /// Note that the given new_grad might not be used directly if it has different + /// metadata (size/stride/storage offset) compared to this Tensor. In that case, + /// new_grad content will be copied into a new Tensor + void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const { + impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op); + } + + /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended + /// to be used from functions that need to access the `Variable`'s equivalent `Tensor` + /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`). + /// + /// One notable difference with the legacy `.data()` function is that changes to the + /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset) + /// will not update the original `Variable`, due to the fact that this function + /// shallow-copies the `Variable`'s underlying TensorImpl. + at::TensorBase tensor_data() const; + + /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data` + /// in Python, which create a new `Variable` that shares the same storage and + /// tensor metadata with the original `Variable`, but with a completely new + /// autograd history. + /// + /// NOTE: If we change the tensor metadata (e.g. sizes / strides / + /// storage / storage_offset) of a variable created from `var.variable_data()`, those + /// changes will not update the original variable `var`. In `.variable_data()`, we set + /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal, + /// in order to prevent users from changing metadata of `var.variable_data()` + /// and expecting the original variable `var` to also be updated. + at::TensorBase variable_data() const; + + // Gradient Node and Edges + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// Gets the gradient function of the `Variable`. If this is a leaf variable, + /// the pointer returned will be null. + /// + /// For View Variables: + /// Gets the up-to-date grad_fn. If the shared data or base was modified, we + /// re-create the grad_fn to express the up-to-date view relationship between + /// this and the base Variable. 
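// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header.
// It mirrors the autograd semantics documented above using the libtorch C++
// frontend (torch/torch.h), which sits outside this header. `autograd_demo`
// is a hypothetical name.
//
//   #include <iostream>
//   #include <torch/torch.h>
//
//   void autograd_demo() {
//     auto x = torch::randn({3}, torch::requires_grad());  // leaf tensor
//     auto y = (x * x).sum();                               // non-leaf result
//     y.backward();                       // accumulates d(sum(x^2))/dx = 2*x
//     std::cout << x.grad() << "\n";      // populated because x is a leaf
//     // y.grad() stays undefined unless y.retain_grad() is called before
//     // backward(), as the is_leaf()/retain_grad() notes above describe.
//   }
// ---------------------------------------------------------------------------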
+ const std::shared_ptr& grad_fn() const; + + // Hooks + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + using hook_return_void_t = std::enable_if_t>, unsigned>; + template + using hook_return_var_t = std::enable_if_t, TensorBase>, unsigned>; + + /// Registers a backward hook. + /// + /// The hook will be called every time a gradient with respect to the Tensor is computed. + /// The hook should have one of the following signature: + /// ``` + /// hook(TensorBase grad) -> TensorBase + /// ``` + /// ``` + /// hook(TensorBase grad) -> void + /// ``` + /// The hook should not modify its argument, but it can optionally return a new gradient + /// which will be used in place of `grad`. + /// + /// This function returns the index of the hook in the list which can be used to remove hook. + /// + /// Example: + /// @code + /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad()); + /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient + /// v.backward(torch::tensor({1., 2., 3.})); + /// // This prints: + /// // ``` + /// // 2 + /// // 4 + /// // 6 + /// // [ CPUFloatType{3} ] + /// // ``` + /// std::cout << v.grad() << std::endl; + /// v.remove_hook(h); // removes the hook + /// @endcode + template + hook_return_void_t register_hook(T&& hook) const; + template + hook_return_var_t register_hook(T&& hook) const; + +protected: + unsigned _register_hook(std::function hook) const; + +public: + + /// Remove hook at given position + void remove_hook(unsigned pos) const; + + // Variable methods + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + bool is_leaf() const; + + int64_t output_nr() const; + + void set_data(const TensorBase & new_data) const; + + TensorBase data() const; + + int64_t _version() const; + + void retain_grad() const; + + bool retains_grad() const; + + const TensorBase& requires_grad_(bool _requires_grad=true) const; + + // View Variables + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// Returns true if this `Variable` is a view of another `Variable`. + bool is_view() const; + + /// Returns the `Variable` that this `Variable` is a view of. If this + /// `Variable` is not a view, throw a `std::runtime_error`. + const TensorBase& _base() const; + + // Miscellaneous + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + const std::string& name() const; + +protected: + void enforce_invariants(); + c10::intrusive_ptr impl_; + +private: + TensorBase __dispatch_contiguous(c10::MemoryFormat) const; +}; + +inline DeviceIndex get_device(const TensorBase& self) { + return self.get_device(); +} + +template +auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t { + // Return the grad argument in case of a hook with void return type to have an + // std::function with Tensor return type + static_assert(std::is_same_v, + "Expected hook to return void"); + return _register_hook([fn=std::forward(hook)](const TensorBase& grad) { + fn(grad); + return TensorBase(); + }); +} + +template +auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t { + return _register_hook(std::forward(hook)); +} + +namespace detail { +// Helper creator for Tensor class which doesn't requires the users to pass +// in an intrusive_ptr instead it just converts the argument passed to +// requested intrusive_ptr type. +template +TensorBase make_tensor_base(Args&&... 
args) { + return TensorBase(c10::make_intrusive(std::forward(args)...)); +} + +} // namespace detail + +inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) { + return legacyExtractDispatchKey(t.key_set()); +} + +} // namespace at + +namespace c10 { +template <> +struct MaybeOwnedTraits { + using owned_type = at::TensorBase; + using borrow_type = at::TensorBase; + + static borrow_type createBorrow(const owned_type& from) { + // NOTE: this can be implemented without the special + // unsafe_borrow_t Tensor constructor as + // + // return borrow_type(c10::intrusive_ptr::reclaim(from.unsafeGetTensorImpl())); + // + // but that hurts inlining due to the nullptr check in the + // Tensor(c10::intrusive_ptr<...>) constructor. We already know + // that from.impl_ isn't null because from is a valid Tensor, so + // we needn't do the check again. (using __builtin_assume can + // avoid this, but wouldn't be portable to MSVC.) + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseTensorImpl(); + // See above note: this can be implemented with public API + // similarly to createBorrow(), but that would hurt inlining. + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0. + } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits : public c10::ExclusivelyOwnedTensorTraits {}; +} // namespace c10 + +namespace at { + +inline c10::MaybeOwned borrow_from_optional_tensor( + const std::optional& opt) { + return opt.has_value() + ? 
c10::MaybeOwned::borrowed(*opt) + : c10::MaybeOwned::owned(std::in_place); +} + +inline c10::MaybeOwned TensorBase::expect_contiguous(MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); + } +} + +namespace symint { + +template +using enable_if_symint = std::enable_if_t>; +template +using enable_if_int = std::enable_if_t>; + +template > +c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); } +template > +IntArrayRef sizes(const TensorBase& t) { return t.sizes(); } + +template > +c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); } +template > +int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); } + +template > +c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); } +template > +IntArrayRef strides(const TensorBase& t) { return t.strides(); } + +template > +c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); } +template > +int64_t numel(const TensorBase& t) { return t.numel(); } + +} // namespace symint + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/TensorBody.h b/phivenv/Lib/site-packages/torch/include/ATen/core/TensorBody.h new file mode 100644 index 0000000000000000000000000000000000000000..335db1a28bdfb29d8788fa1d940d01f77789be0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/TensorBody.h @@ -0,0 +1,5761 @@ +#pragma once + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + +namespace c10{ +template class List; +template class IListRef; +} +namespace at { +struct Generator; +struct Type; +class DeprecatedTypeProperties; +class Tensor; +} // namespace at +namespace at { +namespace indexing { +struct TensorIndex; +} // namespace indexing +} // namespace at + +namespace torch { namespace autograd { + +struct Node; + +}} // namespace torch::autograd + +namespace at { + +class OptionalTensorRef; +class TensorRef; +class Tensor; +using TensorList = ArrayRef; +using ITensorList = c10::IListRef; + +using Stream = c10::Stream; + +// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which +// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr. +// +// For example: +// +// void func(Tensor a) { +// Tensor b = a; +// ... +// } +// +// In this example, when we say Tensor b = a, we are creating a new object that points to the +// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the +// destructor decrements the reference count by calling release() on the TensorImpl it points to. +// The existing constructors, operator overloads, etc. take care to implement the correct semantics. +// +// Note that Tensor can also be NULL, i.e. 
it is not associated with any underlying TensorImpl, and +// special care must be taken to handle this. +class TORCH_API Tensor: public TensorBase { + protected: + // Create a Tensor with a +0 reference count. Special care must be + // taken to avoid decrementing this reference count at destruction + // time. Intended to support MaybeOwnedTraits. + explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {} + friend MaybeOwnedTraits; + friend OptionalTensorRef; + friend TensorRef; + + public: + Tensor() = default; + // This constructor should not be used by end users and is an implementation + // detail invoked by autogenerated code. + explicit Tensor( + c10::intrusive_ptr tensor_impl) + : TensorBase(std::move(tensor_impl)) {} + Tensor(const Tensor &tensor) = default; + Tensor(Tensor &&tensor) = default; + + // Implicitly move-constructible from TensorBase, but must be explicit to increase refcount + explicit Tensor(const TensorBase &base): TensorBase(base) {} + /*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {} + + // Creates a new wrapper from TensorImpl. Intentionally a free method because + // it should be used with care. Checks necessary invariants + static Tensor wrap_tensor_impl( + c10::intrusive_ptr tensor_impl) { + return TensorBase::wrap_tensor_impl(std::move(tensor_impl)); + } + + Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + return TensorBase::contiguous(memory_format); + } + + Tensor conj() const { + if (!this->is_complex()) { + return *this; + } + + switch (this->layout()) { + case at::kSparse: + case at::kSparseCsr: + case at::kSparseCsc: + case at::kSparseBsr: + case at::kSparseBsc: + return this->conj_physical(); + default: + return this->_conj(); + } + } + + // Aliased by Dimname overloads, so need explicit using + using TensorBase::size; + using TensorBase::sym_size; + using TensorBase::stride; + + /// Should be used if *this can reasonably be expected to be contiguous and + /// performance is important. + /// Compared to contiguous, it saves a reference count + /// increment/decrement if *this is already contiguous, at the cost + /// in all cases of an extra pointer of stack usage, an extra branch + /// to access, and an extra branch at destruction time. + c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const &; + + // Use .contiguous() instead. Trying to borrow from a prvalue Tensor + // will only lead to trouble and dangling references. + c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; + + // The following overloads are very intruiging. Consider the following + // program: + // + // x[1] = 3; + // + // We would expect that the first entry of x is written to 3. But how can we + // actually achieve this? x[1] evaluates to a tensor... + // + // The answer is, using a ref-qualifier. x[1] is an rvalue, which cannot be + // (profitably) assigned to in the traditional sense, so we overload + // assignment to mean, "Actually, copy 3 into the tensor data." This is done + // with an rvalue-reference ref-qualified overload (the methods with && at the + // end of their type.) + // + // There's one more fly in the ointment: We also want + // + // Tensor x = y; + // + // to work, and we want it NOT to copy. So we need a traditional operator= + // overload. 
But we MUST specify a mutable lvalue ref-qualifier, to + // disambiguate the traditional overload from the rvalue-reference + // ref-qualified overload. Otherwise, it will be ambiguous, because + // a non ref-qualified method is eligible for all situations. + + // Unfortunately, we have to write these constructors out manually + // to work around an MSVC bug: + // error C2580: 'at::Tensor &at::Tensor::operator =(const at::Tensor &) &': + // multiple versions of a defaulted special member functions are not allowed + // Tensor& operator=(const Tensor&) & = default; + // Tensor& operator=(Tensor&&) & = default; + + // Also MSVC will wrongly issue the following warning with the aforementioned fix + // warning C4522: 'at::Tensor': multiple assignment operators specified + // Let's just skip the warning. + // + // TODO: temporarily disabled + + Tensor& operator=(const TensorBase& x) & noexcept { + impl_ = x.getIntrusivePtr(); + return *this; + } + Tensor& operator=(TensorBase&& x) & noexcept { + impl_ = x.unsafeReleaseIntrusivePtr(); + return *this; + } + + Tensor& operator=(const Tensor &x) & noexcept { + return operator=(static_cast(x)); + } + Tensor& operator=(Tensor &&x) & noexcept { + return operator=(static_cast(x)); + } + + Tensor& operator=(const Scalar &v) && { + return fill_(v); + } + Tensor& operator=(const Tensor &rhs) && { + return copy_(rhs); + } + Tensor& operator=(Tensor&& rhs) && { + return copy_(rhs); + } + + C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().") + DeprecatedTypeProperties & type() const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + dispatchKeyToBackend(legacyExtractDispatchKey(key_set())), + scalar_type()); + } + + Tensor toType(ScalarType t) const { + return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false); + } + + // TODO: Deprecate me + Tensor toBackend(Backend b) const { + return to(options().device(backendToDeviceType(b)).layout(layout_from_backend(b)), /*non_blocking*/ false, /*copy*/ false); + } + + C10_DEPRECATED_MESSAGE("Tensor.is_variable() is deprecated; everything is a variable now. (If you want to assert that variable has been appropriately handled already, use at::impl::variable_excluded_from_dispatch())") + bool is_variable() const noexcept { + return !at::impl::variable_excluded_from_dispatch(); + } + + template + C10_DEPRECATED_MESSAGE("Tensor.data() is deprecated. 
Please use Tensor.data_ptr() instead.") + T * data() const { + return data_ptr(); + } + + template + T item() const; + + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() const & { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() && = delete; + + Tensor operator~() const { + return bitwise_not(); + } + Tensor operator-() const { + return neg(); + } + Tensor& operator+=(const Tensor & other) { + return add_(other); + } + Tensor& operator+=(const Scalar & other) { + return add_(other); + } + Tensor& operator-=(const Tensor & other) { + return sub_(other); + } + Tensor& operator-=(const Scalar & other) { + return sub_(other); + } + Tensor& operator*=(const Tensor & other) { + return mul_(other); + } + Tensor& operator*=(const Scalar & other) { + return mul_(other); + } + Tensor& operator/=(const Tensor & other) { + return div_(other); + } + Tensor& operator/=(const Scalar & other) { + return div_(other); + } + Tensor& operator&=(const Tensor & other) { + return bitwise_and_(other); + } + Tensor& operator|=(const Tensor & other) { + return bitwise_or_(other); + } + Tensor& operator^=(const Tensor & other) { + return bitwise_xor_(other); + } + Tensor operator[](const Scalar & index) const { + if (!index.isIntegral(false)) { + TORCH_CHECK_INDEX(false, "Can only index tensors with integral scalars"); + } + return this->operator[](index.toLong()); + } + Tensor operator[](const Tensor & index) const { + // These properties are checked in the Scalar constructor, but we already + // check them here to provide more useful diagnostics for the user. + if (!index.defined()) { + TORCH_CHECK_INDEX(false, "Can only index with tensors that are defined"); + } + if (index.dim() != 0) { + TORCH_CHECK_INDEX(false, + "Can only index with tensors that are scalars (zero-dim)"); + } + // The Scalar(Tensor) constructor is explicit, so we need to call it. 
+ return this->operator[](index.item()); + } + Tensor operator[](int64_t index) const { + return select(0, index); + } + + Tensor index(ArrayRef indices) const; + Tensor index(std::initializer_list indices) const; + + Tensor & index_put_(ArrayRef indices, Tensor const & rhs); + Tensor & index_put_(ArrayRef indices, const Scalar& v); + Tensor & index_put_(std::initializer_list indices, Tensor const & rhs); + Tensor & index_put_(std::initializer_list indices, const Scalar& v); + + Tensor cpu() const { + return to(options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); + } + + // TODO: The Python version also accepts arguments + Tensor cuda() const { + return to(options().device(c10::DeviceType::CUDA), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor hip() const { + return to(options().device(c10::DeviceType::HIP), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor ve() const { + return to(options().device(c10::DeviceType::VE), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor vulkan() const { + return to(options().device(c10::DeviceType::Vulkan), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor metal() const { + return to(options().device(c10::DeviceType::Metal), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor meta() const { + return to(options().device(c10::DeviceType::Meta), /*non_blocking*/ false, /*copy*/ false); + } + + // ~~~~~ Autograd API ~~~~~ + + /// \fn bool is_leaf() const; + /// + /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention. + /// + /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were + /// created by the user. This means that they are not the result of an operation and so + /// `grad_fn()` is `nullptr`. + /// + /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`. + /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`. + /// + /// Example: + /// @code + /// auto a = torch::rand(10, torch::requires_grad()); + /// std::cout << a.is_leaf() << std::endl; // prints `true` + /// + /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA); + /// std::cout << b.is_leaf() << std::endl; // prints `false` + /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor + /// + /// auto c = torch::rand(10, torch::requires_grad()) + 2; + /// std::cout << c.is_leaf() << std::endl; // prints `false` + /// // c was created by the addition operation + /// + /// auto d = torch::rand(10).cuda(); + /// std::cout << d.is_leaf() << std::endl; // prints `true` + /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + /// + /// auto e = torch::rand(10).cuda().requires_grad_(); + /// std::cout << e.is_leaf() << std::endl; // prints `true` + /// // e requires gradients and has no operations creating it + /// + /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true)); + /// std::cout << f.is_leaf() << std::endl; // prints `true` + /// // f requires grad, has no operation creating it + /// @endcode + + /// \fn void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const; + /// + /// Computes the gradient of current tensor with respect to graph leaves. + /// + /// The graph is differentiated using the chain rule. If the tensor is + /// non-scalar (i.e. 
its data has more than one element) and requires + /// gradient, the function additionally requires specifying ``gradient``. + /// It should be a tensor of matching type and location, that contains + /// the gradient of the differentiated function w.r.t. this Tensor. + /// + /// This function accumulates gradients in the leaves - you might need to + /// zero them before calling it. + /// + /// \param gradient Gradient w.r.t. the + /// tensor. If it is a tensor, it will be automatically converted + /// to a Tensor that does not require grad unless ``create_graph`` is True. + /// None values can be specified for scalar Tensors or ones that + /// don't require grad. If a None value would be acceptable then + /// this argument is optional. + /// \param retain_graph If ``false``, the graph used to compute + /// the grads will be freed. Note that in nearly all cases setting + /// this option to True is not needed and often can be worked around + /// in a much more efficient way. Defaults to the value of + /// ``create_graph``. + /// \param create_graph If ``true``, graph of the derivative will + /// be constructed, allowing to compute higher order derivative + /// products. Defaults to ``false``. + /// \param inputs Inputs w.r.t. which the gradient will be accumulated into + /// ``at::Tensor::grad``. All other Tensors will be ignored. If not + /// provided, the gradient is accumulated into all the leaf Tensors + /// that were used to compute the current tensor. + /// When inputs are provided and a given input is not a leaf, + /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). + /// It is an implementation detail on which the user should not rely. + /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const { + // NB: Adding this wrapper to _backward here because we'd like our + // 'backwards' api to accept the 'inputs' argument optionally. Since code gen + // currently does not support optional of TensorList our approach is to replace + // backward in native_functions.yaml with _backward and call it here instead. + if (inputs.has_value()) { + TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty") + this->_backward(inputs.value(), gradient, retain_graph, create_graph); + } else { + this->_backward({}, gradient, retain_graph, create_graph); + } + } + + /// \fn Tensor detach() const; + /// + /// Returns a new Tensor, detached from the current graph. + /// The result will never require gradient. + + /// \fn Tensor & detach_() const; + /// + /// Detaches the Tensor from the graph that created it, making it a leaf. + /// Views cannot be detached in-place. + + /// \fn void retain_grad() const; + /// + /// Enables this Tensor to have their :attr:`grad` populated during + /// :func:`backward`. This is a no-op for leaf tensors. + + /// \fn bool retains_grad() const; + /// + /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be + /// populated during :func:`backward`, ``false`` otherwise. + + const Tensor& set_requires_grad(bool requires_grad) const { + TensorBase::set_requires_grad(requires_grad); + return *this; + } + + /// Return a mutable reference to the gradient. This is conventionally + /// used as `t.grad() = x` to set a gradient to a completely new tensor. 
+ /// Note that this function work with a non-const Tensor and is not + /// thread safe. + Tensor& mutable_grad() const { + return impl_->mutable_grad(); + } + + /// This function returns an undefined tensor by default and returns a defined tensor + /// the first time a call to `backward()` computes gradients for this Tensor. + /// The attribute will then contain the gradients computed and future calls + /// to `backward()` will accumulate (add) gradients into it. + const Tensor& grad() const { + const Tensor& maybe_grad = impl_->grad(); + if (!is_leaf() && !retains_grad() && !maybe_grad.defined()) { + TORCH_WARN( + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad " + "attribute won't be populated during autograd.backward(). If you indeed want the .grad " + "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. " + "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor " + "instead. See github.com/pytorch/pytorch/pull/30531 for more informations."); + } + return maybe_grad; + } + + // The Forward AD API functions below are low level and are not to be used by end + // users who should use the API provided in torch/csrc/autograd.h + + /// This function returns the forward gradient for this Tensor at the given level. + const Tensor& _fw_grad(uint64_t level) const { + return impl_->_fw_grad(level, *this); + } + + /// This function can be used to set the value of the forward grad. + /// Note that the given new_grad might not be used directly if it has different + /// metadata (size/stride/storage offset) compared to this Tensor. In that case, + /// new_grad content will be copied into a new Tensor + void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const { + impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op); + } + + + // STOP. Thinking of adding a method here, which only makes use + // of other ATen methods? Define it in native_functions.yaml. 
+ + //example + //Tensor * add(Tensor & b); + void __dispatch__backward(at::TensorList inputs, const ::std::optional & gradient={}, ::std::optional retain_graph=::std::nullopt, bool create_graph=false) const; + void __dispatch_set_data(const at::Tensor & new_data) const; + at::Tensor __dispatch_data() const; + bool __dispatch_is_leaf() const; + int64_t __dispatch_output_nr() const; + int64_t __dispatch__version() const; + at::Tensor & __dispatch_requires_grad_(bool requires_grad=true) const; + void __dispatch_retain_grad() const; + bool __dispatch_retains_grad() const; + at::Tensor _fw_primal(int64_t level) const; + at::Tensor & rename_(::std::optional names) const; + at::Tensor rename(::std::optional names) const; + at::Tensor align_to(at::DimnameList names) const; + at::Tensor align_to(at::DimnameList order, int64_t ellipsis_idx) const; + at::Tensor align_as(const at::Tensor & other) const; + at::Tensor refine_names(at::DimnameList names) const; + at::Tensor abs() const; + at::Tensor & abs_() const; + at::Tensor absolute() const; + at::Tensor & absolute_() const; + at::Tensor angle() const; + at::Tensor sgn() const; + at::Tensor & sgn_() const; + at::Tensor chalf(::std::optional memory_format=::std::nullopt) const; + at::Tensor _conj() const; + at::Tensor __dispatch_conj() const; + at::Tensor _conj_physical() const; + at::Tensor conj_physical() const; + at::Tensor & conj_physical_() const; + at::Tensor resolve_conj() const; + at::Tensor resolve_neg() const; + at::Tensor _neg_view() const; + at::Tensor acos() const; + at::Tensor & acos_() const; + at::Tensor arccos() const; + at::Tensor & arccos_() const; + at::Tensor add(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor & add_(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor add(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor & add_(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor addmv(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & addmv_(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor addr(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & addr_(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor _is_all_true() const; + at::Tensor _is_any_true() const; + at::Tensor all(int64_t dim, bool keepdim=false) const; + at::Tensor all(at::OptionalIntArrayRef dim, bool keepdim=false) const; + at::Tensor all(at::Dimname dim, bool keepdim=false) const; + bool allclose(const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const; + at::Tensor any(int64_t dim, bool keepdim=false) const; + at::Tensor any(at::OptionalIntArrayRef dim, bool keepdim=false) const; + at::Tensor any(at::Dimname dim, bool keepdim=false) const; + at::Tensor argmax(::std::optional dim=::std::nullopt, bool keepdim=false) const; + at::Tensor argmin(::std::optional dim=::std::nullopt, bool keepdim=false) const; + at::Tensor acosh() const; + at::Tensor & acosh_() const; + at::Tensor arccosh() const; + at::Tensor & arccosh_() const; + at::Tensor asinh() const; + at::Tensor & asinh_() const; + at::Tensor arcsinh() const; + at::Tensor & arcsinh_() const; + at::Tensor atanh() const; + at::Tensor & atanh_() const; + at::Tensor arctanh() const; + at::Tensor & 
arctanh_() const; + at::Tensor as_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor as_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + const at::Tensor & as_strided_(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + const at::Tensor & as_strided__symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor asin() const; + at::Tensor & asin_() const; + at::Tensor arcsin() const; + at::Tensor & arcsin_() const; + at::Tensor atan() const; + at::Tensor & atan_() const; + at::Tensor arctan() const; + at::Tensor & arctan_() const; + at::Tensor baddbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & baddbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor bernoulli(::std::optional generator=::std::nullopt) const; + at::Tensor & bernoulli_(const at::Tensor & p, ::std::optional generator=::std::nullopt) const; + at::Tensor & bernoulli_(double p=0.5, ::std::optional generator=::std::nullopt) const; + at::Tensor bernoulli(double p, ::std::optional generator=::std::nullopt) const; + at::Tensor bincount(const ::std::optional & weights={}, int64_t minlength=0) const; + at::Tensor bincount_symint(const ::std::optional & weights={}, c10::SymInt minlength=0) const; + at::Tensor bitwise_not() const; + at::Tensor & bitwise_not_() const; + at::Tensor copysign(const at::Tensor & other) const; + at::Tensor & copysign_(const at::Tensor & other) const; + at::Tensor copysign(const at::Scalar & other) const; + at::Tensor & copysign_(const at::Scalar & other) const; + at::Tensor _lazy_clone() const; + at::Tensor logical_not() const; + at::Tensor & logical_not_() const; + at::Tensor logical_xor(const at::Tensor & other) const; + at::Tensor & logical_xor_(const at::Tensor & other) const; + at::Tensor logical_and(const at::Tensor & other) const; + at::Tensor & logical_and_(const at::Tensor & other) const; + at::Tensor logical_or(const at::Tensor & other) const; + at::Tensor & logical_or_(const at::Tensor & other) const; + at::Tensor bmm(const at::Tensor & mat2) const; + at::Tensor broadcast_to(at::IntArrayRef size) const; + at::Tensor broadcast_to_symint(c10::SymIntArrayRef size) const; + at::Tensor ceil() const; + at::Tensor & ceil_() const; + ::std::vector unsafe_chunk(int64_t chunks, int64_t dim=0) const; + ::std::vector chunk(int64_t chunks, int64_t dim=0) const; + ::std::vector tensor_split(int64_t sections, int64_t dim=0) const; + ::std::vector tensor_split_symint(c10::SymInt sections, int64_t dim=0) const; + ::std::vector tensor_split(at::IntArrayRef indices, int64_t dim=0) const; + ::std::vector tensor_split_symint(c10::SymIntArrayRef indices, int64_t dim=0) const; + ::std::vector tensor_split(const at::Tensor & tensor_indices_or_sections, int64_t dim=0) const; + at::Tensor clamp(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor clamp(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor & clamp_(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor & clamp_(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor clamp_max(const at::Scalar & max) 
const; + at::Tensor clamp_max(const at::Tensor & max) const; + at::Tensor & clamp_max_(const at::Scalar & max) const; + at::Tensor & clamp_max_(const at::Tensor & max) const; + at::Tensor clamp_min(const at::Scalar & min) const; + at::Tensor clamp_min(const at::Tensor & min) const; + at::Tensor & clamp_min_(const at::Scalar & min) const; + at::Tensor & clamp_min_(const at::Tensor & min) const; + at::Tensor clip(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor clip(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor & clip_(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor & clip_(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor __dispatch_contiguous(at::MemoryFormat memory_format=c10::MemoryFormat::Contiguous) const; + at::Tensor & copy_(const at::Tensor & src, bool non_blocking=false) const; + at::Tensor cos() const; + at::Tensor & cos_() const; + at::Tensor cosh() const; + at::Tensor & cosh_() const; + at::Tensor count_nonzero(at::IntArrayRef dim) const; + at::Tensor count_nonzero(::std::optional dim=::std::nullopt) const; + at::Tensor cov(int64_t correction=1, const ::std::optional & fweights={}, const ::std::optional & aweights={}) const; + at::Tensor corrcoef() const; + ::std::tuple cummax(int64_t dim) const; + ::std::tuple cummax(at::Dimname dim) const; + ::std::tuple cummin(int64_t dim) const; + ::std::tuple cummin(at::Dimname dim) const; + at::Tensor cumprod(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumprod_(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor cumprod(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumprod_(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor cumsum(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumsum_(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor cumsum(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumsum_(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor diag_embed(int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) const; + at::Tensor diagflat(int64_t offset=0) const; + at::Tensor diagonal(int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const; + at::Tensor diagonal(at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset=0) const; + at::Tensor & fill_diagonal_(const at::Scalar & fill_value, bool wrap=false) const; + at::Tensor diff(int64_t n=1, int64_t dim=-1, const ::std::optional & prepend={}, const ::std::optional & append={}) const; + at::Tensor div(const at::Tensor & other) const; + at::Tensor & div_(const at::Tensor & other) const; + at::Tensor div(const at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor & div_(const at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor div(const at::Scalar & other) const; + at::Tensor & div_(const at::Scalar & other) const; + at::Tensor div(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor & div_(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor divide(const at::Tensor & other) const; + at::Tensor & divide_(const at::Tensor & other) const; + at::Tensor divide(const at::Scalar & other) const; + at::Tensor & divide_(const at::Scalar & other) const; + at::Tensor divide(const at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor & 
divide_(const at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor divide(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor & divide_(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor true_divide(const at::Tensor & other) const; + at::Tensor & true_divide_(const at::Tensor & other) const; + at::Tensor true_divide(const at::Scalar & other) const; + at::Tensor & true_divide_(const at::Scalar & other) const; + at::Tensor dot(const at::Tensor & tensor) const; + at::Tensor vdot(const at::Tensor & other) const; + at::Tensor new_empty(at::IntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_empty(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_empty_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_empty_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) const; + at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) const; + at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_full(at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) const; + at::Tensor new_full(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) const; + at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_zeros(at::IntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_zeros(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_zeros_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_zeros_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_ones(at::IntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_ones(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_ones_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_ones_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + const at::Tensor & resize_(at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) const; + const at::Tensor & resize__symint(c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) const; + at::Tensor erf() const; + at::Tensor & erf_() const; + 
at::Tensor erfc() const; + at::Tensor & erfc_() const; + at::Tensor exp() const; + at::Tensor & exp_() const; + at::Tensor exp2() const; + at::Tensor & exp2_() const; + at::Tensor expm1() const; + at::Tensor & expm1_() const; + at::Tensor expand(at::IntArrayRef size, bool implicit=false) const; + at::Tensor expand_symint(c10::SymIntArrayRef size, bool implicit=false) const; + at::Tensor expand_as(const at::Tensor & other) const; + at::Tensor flatten(int64_t start_dim=0, int64_t end_dim=-1) const; + at::Tensor flatten(int64_t start_dim, int64_t end_dim, at::Dimname out_dim) const; + at::Tensor flatten(at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) const; + at::Tensor flatten(at::DimnameList dims, at::Dimname out_dim) const; + at::Tensor unflatten(int64_t dim, at::IntArrayRef sizes) const; + at::Tensor unflatten_symint(int64_t dim, c10::SymIntArrayRef sizes) const; + at::Tensor unflatten(at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) const; + at::Tensor unflatten_symint(at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) const; + at::Tensor & fill_(const at::Scalar & value) const; + at::Tensor & fill_(const at::Tensor & value) const; + at::Tensor floor() const; + at::Tensor & floor_() const; + at::Tensor floor_divide(const at::Tensor & other) const; + at::Tensor & floor_divide_(const at::Tensor & other) const; + at::Tensor floor_divide(const at::Scalar & other) const; + at::Tensor & floor_divide_(const at::Scalar & other) const; + at::Tensor frac() const; + at::Tensor & frac_() const; + at::Tensor gcd(const at::Tensor & other) const; + at::Tensor & gcd_(const at::Tensor & other) const; + at::Tensor lcm(const at::Tensor & other) const; + at::Tensor & lcm_(const at::Tensor & other) const; + at::Tensor index(const c10::List<::std::optional> & indices) const; + at::Tensor & index_copy_(int64_t dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor index_copy(int64_t dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor & index_copy_(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor index_copy(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor & index_put_(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) const; + at::Tensor index_put(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) const; + at::Tensor isclose(const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const; + at::Tensor isnan() const; + bool is_distributed() const; + bool __dispatch_is_floating_point() const; + bool __dispatch_is_complex() const; + bool __dispatch_is_conj() const; + bool __dispatch__is_zerotensor() const; + bool __dispatch_is_neg() const; + at::Tensor isreal() const; + bool is_nonzero() const; + bool is_same_size(const at::Tensor & other) const; + bool __dispatch_is_signed() const; + bool __dispatch_is_inference() const; + at::Tensor kron(const at::Tensor & other) const; + ::std::tuple kthvalue(int64_t k, int64_t dim=-1, bool keepdim=false) const; + ::std::tuple kthvalue_symint(c10::SymInt k, int64_t dim=-1, bool keepdim=false) const; + ::std::tuple kthvalue(int64_t k, at::Dimname dim, bool keepdim=false) const; + ::std::tuple kthvalue_symint(c10::SymInt k, at::Dimname dim, bool keepdim=false) const; + at::Tensor nan_to_num(::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional 
neginf=::std::nullopt) const; + at::Tensor & nan_to_num_(::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) const; + at::Tensor ldexp(const at::Tensor & other) const; + at::Tensor & ldexp_(const at::Tensor & other) const; + at::Tensor log() const; + at::Tensor & log_() const; + at::Tensor log10() const; + at::Tensor & log10_() const; + at::Tensor log1p() const; + at::Tensor & log1p_() const; + at::Tensor log2() const; + at::Tensor & log2_() const; + at::Tensor logaddexp(const at::Tensor & other) const; + at::Tensor logaddexp2(const at::Tensor & other) const; + at::Tensor xlogy(const at::Tensor & other) const; + at::Tensor xlogy(const at::Scalar & other) const; + at::Tensor & xlogy_(const at::Tensor & other) const; + at::Tensor & xlogy_(const at::Scalar & other) const; + at::Tensor log_softmax(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor log_softmax(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor logcumsumexp(int64_t dim) const; + at::Tensor logcumsumexp(at::Dimname dim) const; + at::Tensor logsumexp(at::IntArrayRef dim, bool keepdim=false) const; + at::Tensor logsumexp(at::DimnameList dim, bool keepdim=false) const; + at::Tensor matmul(const at::Tensor & other) const; + at::Tensor matrix_power(int64_t n) const; + at::Tensor matrix_exp() const; + ::std::tuple aminmax(::std::optional dim=::std::nullopt, bool keepdim=false) const; + ::std::tuple max(int64_t dim, bool keepdim=false) const; + ::std::tuple max(at::Dimname dim, bool keepdim=false) const; + at::Tensor amax(at::IntArrayRef dim={}, bool keepdim=false) const; + at::Tensor mean(::std::optional dtype=::std::nullopt) const; + at::Tensor mean(at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor mean(at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor nanmean(at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor median() const; + ::std::tuple median(int64_t dim, bool keepdim=false) const; + ::std::tuple median(at::Dimname dim, bool keepdim=false) const; + at::Tensor nanmedian() const; + ::std::tuple nanmedian(int64_t dim, bool keepdim=false) const; + ::std::tuple nanmedian(at::Dimname dim, bool keepdim=false) const; + ::std::tuple min(int64_t dim, bool keepdim=false) const; + ::std::tuple min(at::Dimname dim, bool keepdim=false) const; + at::Tensor amin(at::IntArrayRef dim={}, bool keepdim=false) const; + at::Tensor mm(const at::Tensor & mat2) const; + ::std::tuple mode(int64_t dim=-1, bool keepdim=false) const; + ::std::tuple mode(at::Dimname dim, bool keepdim=false) const; + at::Tensor mul(const at::Tensor & other) const; + at::Tensor & mul_(const at::Tensor & other) const; + at::Tensor mul(const at::Scalar & other) const; + at::Tensor & mul_(const at::Scalar & other) const; + at::Tensor multiply(const at::Tensor & other) const; + at::Tensor & multiply_(const at::Tensor & other) const; + at::Tensor multiply(const at::Scalar & other) const; + at::Tensor & multiply_(const at::Scalar & other) const; + at::Tensor mv(const at::Tensor & vec) const; + at::Tensor mvlgamma(int64_t p) const; + at::Tensor & mvlgamma_(int64_t p) const; + at::Tensor narrow_copy(int64_t dim, int64_t start, int64_t length) const; + at::Tensor narrow_copy_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const; + at::Tensor narrow(int64_t dim, int64_t start, int64_t length) const; 
+ at::Tensor narrow_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const; + at::Tensor narrow(int64_t dim, const at::Tensor & start, int64_t length) const; + at::Tensor narrow_symint(int64_t dim, const at::Tensor & start, c10::SymInt length) const; + at::Tensor permute(at::IntArrayRef dims) const; + at::Tensor movedim(at::IntArrayRef source, at::IntArrayRef destination) const; + at::Tensor movedim(int64_t source, int64_t destination) const; + at::Tensor moveaxis(at::IntArrayRef source, at::IntArrayRef destination) const; + at::Tensor moveaxis(int64_t source, int64_t destination) const; + at::Tensor numpy_T() const; + at::Tensor matrix_H() const; + at::Tensor mT() const; + at::Tensor mH() const; + at::Tensor adjoint() const; + bool is_pinned(::std::optional device=::std::nullopt) const; + at::Tensor pin_memory(::std::optional device=::std::nullopt) const; + at::Tensor pinverse(double rcond=1e-15) const; + at::Tensor rad2deg() const; + at::Tensor & rad2deg_() const; + at::Tensor deg2rad() const; + at::Tensor & deg2rad_() const; + at::Tensor ravel() const; + at::Tensor reciprocal() const; + at::Tensor & reciprocal_() const; + at::Tensor neg() const; + at::Tensor & neg_() const; + at::Tensor negative() const; + at::Tensor & negative_() const; + at::Tensor repeat(at::IntArrayRef repeats) const; + at::Tensor repeat_symint(c10::SymIntArrayRef repeats) const; + at::Tensor repeat_interleave(const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor repeat_interleave_symint(const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor repeat_interleave(int64_t repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor repeat_interleave_symint(c10::SymInt repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor reshape(at::IntArrayRef shape) const; + at::Tensor reshape_symint(c10::SymIntArrayRef shape) const; + at::Tensor _reshape_alias(at::IntArrayRef size, at::IntArrayRef stride) const; + at::Tensor _reshape_alias_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const; + at::Tensor reshape_as(const at::Tensor & other) const; + at::Tensor round() const; + at::Tensor & round_() const; + at::Tensor round(int64_t decimals) const; + at::Tensor & round_(int64_t decimals) const; + at::Tensor relu() const; + at::Tensor & relu_() const; + at::Tensor prelu(const at::Tensor & weight) const; + at::Tensor hardshrink(const at::Scalar & lambd=0.5) const; + at::Tensor hardshrink_backward(const at::Tensor & grad_out, const at::Scalar & lambd) const; + at::Tensor rsqrt() const; + at::Tensor & rsqrt_() const; + at::Tensor select(at::Dimname dim, int64_t index) const; + at::Tensor select(int64_t dim, int64_t index) const; + at::Tensor select_symint(int64_t dim, c10::SymInt index) const; + at::Tensor sigmoid() const; + at::Tensor & sigmoid_() const; + at::Tensor logit(::std::optional eps=::std::nullopt) const; + at::Tensor & logit_(::std::optional eps=::std::nullopt) const; + at::Tensor sin() const; + at::Tensor & sin_() const; + at::Tensor sinc() const; + at::Tensor & sinc_() const; + at::Tensor sinh() const; + at::Tensor & sinh_() const; + at::Tensor detach() const; + at::Tensor & detach_() const; + int64_t size(at::Dimname dim) const; + at::Tensor slice(int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) 
const; + at::Tensor slice_symint(int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) const; + at::Tensor slice_inverse(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) const; + at::Tensor slice_inverse_symint(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) const; + at::Tensor slice_scatter(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) const; + at::Tensor slice_scatter_symint(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) const; + at::Tensor select_scatter(const at::Tensor & src, int64_t dim, int64_t index) const; + at::Tensor select_scatter_symint(const at::Tensor & src, int64_t dim, c10::SymInt index) const; + at::Tensor diagonal_scatter(const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const; + at::Tensor as_strided_scatter(const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor as_strided_scatter_symint(const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor smm(const at::Tensor & mat2) const; + at::Tensor softmax(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor softmax(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + ::std::vector unsafe_split(int64_t split_size, int64_t dim=0) const; + ::std::vector unsafe_split_symint(c10::SymInt split_size, int64_t dim=0) const; + ::std::vector split(int64_t split_size, int64_t dim=0) const; + ::std::vector split_symint(c10::SymInt split_size, int64_t dim=0) const; + ::std::vector split(at::IntArrayRef split_size, int64_t dim=0) const; + ::std::vector split_symint(c10::SymIntArrayRef split_size, int64_t dim=0) const; + ::std::vector unsafe_split_with_sizes(at::IntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector unsafe_split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector split_with_sizes(at::IntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector hsplit(int64_t sections) const; + ::std::vector hsplit(at::IntArrayRef indices) const; + ::std::vector vsplit(int64_t sections) const; + ::std::vector vsplit(at::IntArrayRef indices) const; + ::std::vector dsplit(int64_t sections) const; + ::std::vector dsplit(at::IntArrayRef indices) const; + at::Tensor squeeze() const; + at::Tensor squeeze(int64_t dim) const; + at::Tensor squeeze(at::Dimname dim) const; + at::Tensor squeeze(at::IntArrayRef dim) const; + at::Tensor & squeeze_() const; + at::Tensor & squeeze_(int64_t dim) const; + at::Tensor & squeeze_(at::IntArrayRef dim) const; + at::Tensor & squeeze_(at::Dimname dim) const; + at::Tensor sspaddmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor stft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) const; + at::Tensor 
stft(int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, c10::string_view pad_mode="reflect", bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) const; + at::Tensor istft(int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional length=::std::nullopt, bool return_complex=false) const; + int64_t stride(at::Dimname dim) const; + at::Tensor sum(::std::optional dtype=::std::nullopt) const; + at::Tensor sum(at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor sum(at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor nansum(at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor sum_to_size(at::IntArrayRef size) const; + at::Tensor sum_to_size_symint(c10::SymIntArrayRef size) const; + at::Tensor sqrt() const; + at::Tensor & sqrt_() const; + at::Tensor square() const; + at::Tensor & square_() const; + at::Tensor std(bool unbiased) const; + at::Tensor std(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) const; + at::Tensor std(at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor std(at::DimnameList dim, bool unbiased, bool keepdim=false) const; + at::Tensor std(at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor prod(::std::optional dtype=::std::nullopt) const; + at::Tensor prod(int64_t dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor prod(at::Dimname dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor t() const; + at::Tensor & t_() const; + at::Tensor tan() const; + at::Tensor & tan_() const; + at::Tensor tanh() const; + at::Tensor & tanh_() const; + at::Tensor tile(at::IntArrayRef dims) const; + at::Tensor tile_symint(c10::SymIntArrayRef dims) const; + at::Tensor transpose(int64_t dim0, int64_t dim1) const; + at::Tensor transpose(at::Dimname dim0, at::Dimname dim1) const; + at::Tensor & transpose_(int64_t dim0, int64_t dim1) const; + at::Tensor flip(at::IntArrayRef dims) const; + at::Tensor fliplr() const; + at::Tensor flipud() const; + at::Tensor roll(at::IntArrayRef shifts, at::IntArrayRef dims={}) const; + at::Tensor roll_symint(c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) const; + at::Tensor rot90(int64_t k=1, at::IntArrayRef dims={0,1}) const; + at::Tensor _nested_tensor_size() const; + at::Tensor _nested_tensor_strides() const; + at::Tensor _nested_tensor_storage_offsets() const; + at::Tensor trunc() const; + at::Tensor & trunc_() const; + at::Tensor fix() const; + at::Tensor & fix_() const; + at::Tensor type_as(const at::Tensor & other) const; + at::Tensor unsqueeze(int64_t dim) const; + at::Tensor & unsqueeze_(int64_t dim) const; + at::Tensor var(bool unbiased) const; + at::Tensor var(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) const; + at::Tensor var(at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor 
var(at::DimnameList dim, bool unbiased, bool keepdim=false) const; + at::Tensor var(at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor view_as(const at::Tensor & other) const; + at::Tensor where(const at::Tensor & condition, const at::Tensor & other) const; + at::Tensor where(const at::Tensor & condition, const at::Scalar & other) const; + at::Tensor norm(const ::std::optional & p, at::ScalarType dtype) const; + at::Tensor norm(const at::Scalar & p=2) const; + at::Tensor norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) const; + at::Tensor norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim=false) const; + at::Tensor norm(const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) const; + at::Tensor norm(const ::std::optional & p, at::DimnameList dim, bool keepdim=false) const; + ::std::tuple frexp() const; + at::Tensor clone(::std::optional memory_format=::std::nullopt) const; + at::Tensor positive() const; + const at::Tensor & resize_as_(const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) const; + const at::Tensor & resize_as_sparse_(const at::Tensor & the_template) const; + at::Tensor & zero_() const; + at::Tensor sub(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor & sub_(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor sub(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor & sub_(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor subtract(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor & subtract_(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor subtract(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor & subtract_(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor heaviside(const at::Tensor & values) const; + at::Tensor & heaviside_(const at::Tensor & values) const; + at::Tensor addmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & addmm_(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor _addmm_activation(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) const; + const at::Tensor & sparse_resize_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const; + const at::Tensor & sparse_resize_and_clear_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const; + at::Tensor sparse_mask(const at::Tensor & mask) const; + at::Tensor _sparse_mask_projection(const at::Tensor & mask, bool accumulate_matches=false) const; + at::Tensor to_dense(::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) const; + at::Tensor _to_dense(::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) const; + int64_t sparse_dim() const; + int64_t _dimI() const; + int64_t dense_dim() const; + int64_t _dimV() const; + int64_t _nnz() const; + at::Tensor coalesce() const; + bool is_coalesced() const; + at::Tensor _indices() const; + at::Tensor _values() const; + at::Tensor & _coalesced_(bool coalesced) const; + at::Tensor indices() const; + at::Tensor values() const; + at::Tensor crow_indices() const; + at::Tensor col_indices() const; + at::Tensor ccol_indices() const; 
+ at::Tensor row_indices() const; + ::std::vector unbind(int64_t dim=0) const; + ::std::vector unbind(at::Dimname dim) const; + at::Tensor to_sparse(int64_t sparse_dim) const; + at::Tensor _to_sparse(int64_t sparse_dim) const; + at::Tensor to_sparse(::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse(::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_csr(::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_csr(::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_csc(::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_csc(::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_mkldnn(::std::optional dtype=::std::nullopt) const; + at::Tensor dequantize() const; + double q_scale() const; + int64_t q_zero_point() const; + at::Tensor q_per_channel_scales() const; + at::Tensor q_per_channel_zero_points() const; + int64_t q_per_channel_axis() const; + at::Tensor int_repr() const; + at::QScheme qscheme() const; + at::Tensor _autocast_to_reduced_precision(bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) const; + at::Tensor _autocast_to_full_precision(bool cuda_enabled, bool cpu_enabled) const; + at::Tensor to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Tensor to(::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) const; + at::Tensor to(at::Device device, at::ScalarType dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Tensor to(at::ScalarType dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Tensor to(const at::Tensor & other, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Scalar item() const; + at::Tensor & set_(at::Storage source) const; + at::Tensor & set_(at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) const; + at::Tensor & set__symint(at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) const; + at::Tensor & set_(const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) const; + at::Tensor & set__symint(const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) const; + at::Tensor & set_(const at::Tensor & source) const; + at::Tensor & set_() const; + bool is_set_to(const at::Tensor & tensor) const; + at::Tensor & masked_fill_(const at::Tensor & mask, const at::Scalar & value) const; + at::Tensor masked_fill(const at::Tensor & mask, const at::Scalar & value) const; + at::Tensor & masked_fill_(const at::Tensor & mask, 
const at::Tensor & value) const; + at::Tensor masked_fill(const at::Tensor & mask, const at::Tensor & value) const; + at::Tensor & masked_scatter_(const at::Tensor & mask, const at::Tensor & source) const; + at::Tensor masked_scatter(const at::Tensor & mask, const at::Tensor & source) const; + at::Tensor view(at::IntArrayRef size) const; + at::Tensor view_symint(c10::SymIntArrayRef size) const; + at::Tensor view(at::ScalarType dtype) const; + at::Tensor & put_(const at::Tensor & index, const at::Tensor & source, bool accumulate=false) const; + at::Tensor put(const at::Tensor & index, const at::Tensor & source, bool accumulate=false) const; + at::Tensor & index_add_(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const; + at::Tensor index_add(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const; + at::Tensor index_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const; + at::Tensor & index_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) const; + at::Tensor index_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) const; + at::Tensor & index_fill_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor index_fill(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor & index_fill_(int64_t dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor index_fill(int64_t dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor & index_fill_(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor & index_fill_(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor index_fill(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor index_fill(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const; + at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const; + at::Tensor scatter(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor scatter_add(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor & scatter_add_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & src, 
c10::string_view reduce, bool include_self=true) const; + at::Tensor & scatter_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) const; + at::Tensor & eq_(const at::Scalar & other) const; + at::Tensor & eq_(const at::Tensor & other) const; + at::Tensor bitwise_and(const at::Scalar & other) const; + at::Tensor bitwise_and(const at::Tensor & other) const; + at::Tensor & bitwise_and_(const at::Scalar & other) const; + at::Tensor & bitwise_and_(const at::Tensor & other) const; + at::Tensor __and__(const at::Scalar & other) const; + at::Tensor __and__(const at::Tensor & other) const; + at::Tensor & __iand__(const at::Scalar & other) const; + at::Tensor & __iand__(const at::Tensor & other) const; + at::Tensor bitwise_or(const at::Scalar & other) const; + at::Tensor bitwise_or(const at::Tensor & other) const; + at::Tensor & bitwise_or_(const at::Scalar & other) const; + at::Tensor & bitwise_or_(const at::Tensor & other) const; + at::Tensor __or__(const at::Scalar & other) const; + at::Tensor __or__(const at::Tensor & other) const; + at::Tensor & __ior__(const at::Scalar & other) const; + at::Tensor & __ior__(const at::Tensor & other) const; + at::Tensor bitwise_xor(const at::Scalar & other) const; + at::Tensor bitwise_xor(const at::Tensor & other) const; + at::Tensor & bitwise_xor_(const at::Scalar & other) const; + at::Tensor & bitwise_xor_(const at::Tensor & other) const; + at::Tensor __xor__(const at::Scalar & other) const; + at::Tensor __xor__(const at::Tensor & other) const; + at::Tensor & __ixor__(const at::Scalar & other) const; + at::Tensor & __ixor__(const at::Tensor & other) const; + at::Tensor __lshift__(const at::Scalar & other) const; + at::Tensor __lshift__(const at::Tensor & other) const; + at::Tensor & __ilshift__(const at::Scalar & other) const; + at::Tensor & __ilshift__(const at::Tensor & other) const; + at::Tensor bitwise_left_shift(const at::Tensor & other) const; + at::Tensor & bitwise_left_shift_(const at::Tensor & other) const; + at::Tensor bitwise_left_shift(const at::Scalar & other) const; + at::Tensor & bitwise_left_shift_(const at::Scalar & other) const; + at::Tensor __rshift__(const at::Scalar & other) const; + at::Tensor __rshift__(const at::Tensor & other) const; + at::Tensor & __irshift__(const at::Scalar & other) const; + at::Tensor & __irshift__(const at::Tensor & other) const; + at::Tensor bitwise_right_shift(const at::Tensor & other) const; + at::Tensor & bitwise_right_shift_(const at::Tensor & other) const; + at::Tensor bitwise_right_shift(const at::Scalar & other) const; + at::Tensor & bitwise_right_shift_(const at::Scalar & other) const; + at::Tensor & tril_(int64_t diagonal=0) const; + at::Tensor & triu_(int64_t diagonal=0) const; + at::Tensor & digamma_() const; + at::Tensor & lerp_(const at::Tensor & end, const at::Scalar & weight) const; + at::Tensor & lerp_(const at::Tensor & end, const at::Tensor & weight) const; + at::Tensor & addbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor addbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & random_(int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) const; + at::Tensor & random_(int64_t to, ::std::optional generator=::std::nullopt) const; + at::Tensor & random_(::std::optional generator=::std::nullopt) const; + at::Tensor & uniform_(double from=0, 
double to=1, ::std::optional generator=::std::nullopt) const; + at::Tensor & cauchy_(double median=0, double sigma=1, ::std::optional generator=::std::nullopt) const; + at::Tensor & log_normal_(double mean=1, double std=2, ::std::optional generator=::std::nullopt) const; + at::Tensor & exponential_(double lambd=1, ::std::optional generator=::std::nullopt) const; + at::Tensor & geometric_(double p, ::std::optional generator=::std::nullopt) const; + at::Tensor diag(int64_t diagonal=0) const; + at::Tensor cross(const at::Tensor & other, ::std::optional dim=::std::nullopt) const; + at::Tensor triu(int64_t diagonal=0) const; + at::Tensor tril(int64_t diagonal=0) const; + at::Tensor trace() const; + at::Tensor ne(const at::Scalar & other) const; + at::Tensor ne(const at::Tensor & other) const; + at::Tensor & ne_(const at::Scalar & other) const; + at::Tensor & ne_(const at::Tensor & other) const; + at::Tensor not_equal(const at::Scalar & other) const; + at::Tensor not_equal(const at::Tensor & other) const; + at::Tensor & not_equal_(const at::Scalar & other) const; + at::Tensor & not_equal_(const at::Tensor & other) const; + at::Tensor eq(const at::Scalar & other) const; + at::Tensor eq(const at::Tensor & other) const; + at::Tensor ge(const at::Scalar & other) const; + at::Tensor ge(const at::Tensor & other) const; + at::Tensor & ge_(const at::Scalar & other) const; + at::Tensor & ge_(const at::Tensor & other) const; + at::Tensor greater_equal(const at::Scalar & other) const; + at::Tensor greater_equal(const at::Tensor & other) const; + at::Tensor & greater_equal_(const at::Scalar & other) const; + at::Tensor & greater_equal_(const at::Tensor & other) const; + at::Tensor le(const at::Scalar & other) const; + at::Tensor le(const at::Tensor & other) const; + at::Tensor & le_(const at::Scalar & other) const; + at::Tensor & le_(const at::Tensor & other) const; + at::Tensor less_equal(const at::Scalar & other) const; + at::Tensor less_equal(const at::Tensor & other) const; + at::Tensor & less_equal_(const at::Scalar & other) const; + at::Tensor & less_equal_(const at::Tensor & other) const; + at::Tensor gt(const at::Scalar & other) const; + at::Tensor gt(const at::Tensor & other) const; + at::Tensor & gt_(const at::Scalar & other) const; + at::Tensor & gt_(const at::Tensor & other) const; + at::Tensor greater(const at::Scalar & other) const; + at::Tensor greater(const at::Tensor & other) const; + at::Tensor & greater_(const at::Scalar & other) const; + at::Tensor & greater_(const at::Tensor & other) const; + at::Tensor lt(const at::Scalar & other) const; + at::Tensor lt(const at::Tensor & other) const; + at::Tensor & lt_(const at::Scalar & other) const; + at::Tensor & lt_(const at::Tensor & other) const; + at::Tensor less(const at::Scalar & other) const; + at::Tensor less(const at::Tensor & other) const; + at::Tensor & less_(const at::Scalar & other) const; + at::Tensor & less_(const at::Tensor & other) const; + at::Tensor take(const at::Tensor & index) const; + at::Tensor take_along_dim(const at::Tensor & indices, ::std::optional dim=::std::nullopt) const; + at::Tensor index_select(int64_t dim, const at::Tensor & index) const; + at::Tensor index_select(at::Dimname dim, const at::Tensor & index) const; + at::Tensor masked_select(const at::Tensor & mask) const; + at::Tensor nonzero() const; + at::Tensor nonzero_static(int64_t size, int64_t fill_value=-1) const; + at::Tensor nonzero_static_symint(c10::SymInt size, int64_t fill_value=-1) const; + ::std::vector nonzero_numpy() const; + at::Tensor 
argwhere() const; + at::Tensor gather(int64_t dim, const at::Tensor & index, bool sparse_grad=false) const; + at::Tensor gather(at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) const; + at::Tensor addcmul(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + at::Tensor & addcmul_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + at::Tensor addcdiv(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + at::Tensor & addcdiv_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + ::std::tuple triangular_solve(const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const; + ::std::tuple svd(bool some=true, bool compute_uv=true) const; + at::Tensor swapaxes(int64_t axis0, int64_t axis1) const; + at::Tensor & swapaxes_(int64_t axis0, int64_t axis1) const; + at::Tensor swapdims(int64_t dim0, int64_t dim1) const; + at::Tensor & swapdims_(int64_t dim0, int64_t dim1) const; + at::Tensor cholesky(bool upper=false) const; + at::Tensor cholesky_solve(const at::Tensor & input2, bool upper=false) const; + at::Tensor cholesky_inverse(bool upper=false) const; + ::std::tuple qr(bool some=true) const; + ::std::tuple geqrf() const; + at::Tensor orgqr(const at::Tensor & input2) const; + at::Tensor ormqr(const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) const; + at::Tensor lu_solve(const at::Tensor & LU_data, const at::Tensor & LU_pivots) const; + at::Tensor multinomial(int64_t num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) const; + at::Tensor multinomial_symint(c10::SymInt num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) const; + at::Tensor & lgamma_() const; + at::Tensor lgamma() const; + at::Tensor digamma() const; + at::Tensor polygamma(int64_t n) const; + at::Tensor & polygamma_(int64_t n) const; + at::Tensor erfinv() const; + at::Tensor & erfinv_() const; + at::Tensor i0() const; + at::Tensor & i0_() const; + at::Tensor sign() const; + at::Tensor & sign_() const; + at::Tensor signbit() const; + at::Tensor dist(const at::Tensor & other, const at::Scalar & p=2) const; + at::Tensor & atan2_(const at::Tensor & other) const; + at::Tensor atan2(const at::Tensor & other) const; + at::Tensor arctan2(const at::Tensor & other) const; + at::Tensor & arctan2_(const at::Tensor & other) const; + at::Tensor lerp(const at::Tensor & end, const at::Scalar & weight) const; + at::Tensor lerp(const at::Tensor & end, const at::Tensor & weight) const; + at::Tensor histc(int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) const; + ::std::tuple histogram(const at::Tensor & bins, const ::std::optional & weight={}, bool density=false) const; + ::std::tuple histogram(int64_t bins=100, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) const; + at::Tensor fmod(const at::Scalar & other) const; + at::Tensor & fmod_(const at::Scalar & other) const; + at::Tensor fmod(const at::Tensor & other) const; + at::Tensor & fmod_(const at::Tensor & other) const; + at::Tensor hypot(const at::Tensor & other) const; + at::Tensor & hypot_(const at::Tensor & other) const; + at::Tensor igamma(const at::Tensor & other) const; + at::Tensor & igamma_(const at::Tensor & other) const; + at::Tensor igammac(const at::Tensor & other) const; + at::Tensor & igammac_(const at::Tensor & other) const; 
+ at::Tensor nextafter(const at::Tensor & other) const; + at::Tensor & nextafter_(const at::Tensor & other) const; + at::Tensor remainder(const at::Scalar & other) const; + at::Tensor & remainder_(const at::Scalar & other) const; + at::Tensor remainder(const at::Tensor & other) const; + at::Tensor & remainder_(const at::Tensor & other) const; + at::Tensor min() const; + at::Tensor fmin(const at::Tensor & other) const; + at::Tensor max() const; + at::Tensor fmax(const at::Tensor & other) const; + at::Tensor maximum(const at::Tensor & other) const; + at::Tensor max(const at::Tensor & other) const; + at::Tensor minimum(const at::Tensor & other) const; + at::Tensor min(const at::Tensor & other) const; + at::Tensor quantile(const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + at::Tensor quantile(double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + at::Tensor nanquantile(const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + at::Tensor nanquantile(double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + ::std::tuple sort(int64_t dim=-1, bool descending=false) const; + ::std::tuple sort(::std::optional stable, int64_t dim=-1, bool descending=false) const; + ::std::tuple sort(at::Dimname dim, bool descending=false) const; + ::std::tuple sort(::std::optional stable, at::Dimname dim, bool descending=false) const; + at::Tensor msort() const; + at::Tensor argsort(int64_t dim=-1, bool descending=false) const; + at::Tensor argsort(bool stable, int64_t dim=-1, bool descending=false) const; + at::Tensor argsort(at::Dimname dim, bool descending=false) const; + ::std::tuple topk(int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) const; + ::std::tuple topk_symint(c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) const; + at::Tensor all() const; + at::Tensor any() const; + at::Tensor renorm(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const; + at::Tensor & renorm_(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const; + at::Tensor unfold(int64_t dimension, int64_t size, int64_t step) const; + bool equal(const at::Tensor & other) const; + at::Tensor pow(const at::Tensor & exponent) const; + at::Tensor pow(const at::Scalar & exponent) const; + at::Tensor & pow_(const at::Scalar & exponent) const; + at::Tensor & pow_(const at::Tensor & exponent) const; + at::Tensor float_power(const at::Tensor & exponent) const; + at::Tensor float_power(const at::Scalar & exponent) const; + at::Tensor & float_power_(const at::Scalar & exponent) const; + at::Tensor & float_power_(const at::Tensor & exponent) const; + at::Tensor & normal_(double mean=0, double std=1, ::std::optional generator=::std::nullopt) const; + at::Tensor alias() const; + at::Tensor isfinite() const; + at::Tensor isinf() const; + void record_stream(at::Stream s) const; + at::Tensor isposinf() const; + at::Tensor isneginf() const; + at::Tensor det() const; + ::std::tuple slogdet() const; + at::Tensor logdet() const; + at::Tensor inverse() const; + at::Tensor inner(const at::Tensor & other) const; + at::Tensor outer(const at::Tensor & vec2) const; + at::Tensor ger(const at::Tensor & vec2) const; + at::Tensor to_padded_tensor(double padding, at::OptionalIntArrayRef output_size=::std::nullopt) const; + at::Tensor 
to_padded_tensor_symint(double padding, at::OptionalSymIntArrayRef output_size=::std::nullopt) const;
+
+  // Special C++ only overloads for std()-like functions (See gh-40287)
+  // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
+  // So, for example std(0) would select the std(unbiased=False) overload
+
+  Tensor var(int dim) const {
+    return var(IntArrayRef{dim});
+  }
+
+  Tensor std(int dim) const {
+    return std(IntArrayRef{dim});
+  }
+
+  // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want the
+  // at::kDouble and its friends to be TypeMeta's, but that hasn't happened yet.
+  // Before that change, we make this method to maintain BC for C++ usage like
+  // `x.to(y.dtype)`.
+  // TODO: remove following two after at::kDouble and its friends are TypeMeta's.
+  inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+  inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+
+  template <typename F, typename... Args>
+  decltype(auto) m(F func, Args&&... params) const {
+    return func(*this, std::forward<Args>(params)...);
+  }
+
+  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
+  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
+  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
+  ///
+  /// One notable difference with the legacy `.data()` function is that changes to the
+  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
+  /// will not update the original `Variable`, due to the fact that this function
+  /// shallow-copies the `Variable`'s underlying TensorImpl.
+  at::Tensor tensor_data() const {
+    return TensorBase::tensor_data();
+  }
+
+  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
+  /// in Python, which creates a new `Variable` that shares the same storage and
+  /// tensor metadata with the original `Variable`, but with a completely new
+  /// autograd history.
+  ///
+  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
+  /// changes will not update the original variable `var`. In `.variable_data()`, we set
+  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
+  /// in order to prevent users from changing metadata of `var.variable_data()`
+  /// and expecting the original variable `var` to also be updated.
+  at::Tensor variable_data() const {
+    return TensorBase::variable_data();
+  }
+
+  // Hooks
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  template <typename T>
+  using hook_return_void_t = std::enable_if_t<std::is_void<std::invoke_result_t<T&, Tensor>>::value, unsigned>;
+  template <typename T>
+  using hook_return_var_t = std::enable_if_t<std::is_same_v<std::invoke_result_t<T&, Tensor>, Tensor>, unsigned>;
+
+  /// Registers a backward hook.
+  ///
+  /// The hook will be called every time a gradient with respect to the Tensor is computed.
+  /// The hook should have one of the following signatures:
+  /// ```
+  /// hook(Tensor grad) -> Tensor
+  /// ```
+  /// ```
+  /// hook(Tensor grad) -> void
+  /// ```
+  /// The hook should not modify its argument, but it can optionally return a new gradient
+  /// which will be used in place of `grad`.
+  ///
+  /// This function returns the index of the hook in the list which can be used to remove the hook.
+  ///
+  /// Example:
+  /// @code
+  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
+  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
+  /// v.backward(torch::tensor({1., 2., 3.}));
+  /// // This prints:
+  /// // ```
+  /// // 2
+  /// // 4
+  /// // 6
+  /// // [ CPUFloatType{3} ]
+  /// // ```
+  /// std::cout << v.grad() << std::endl;
+  /// v.remove_hook(h); // removes the hook
+  /// @endcode
+  template <typename T>
+  hook_return_void_t<T> register_hook(T&& hook) const;
+  template <typename T>
+  hook_return_var_t<T> register_hook(T&& hook) const;
+
+  // Variable methods
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  Tensor data() const {
+    return TensorBase::data();
+  }
+
+  void _backward(TensorList inputs, const std::optional<Tensor>& gradient, std::optional<bool> keep_graph, bool create_graph) const;
+
+  const Tensor& requires_grad_(bool _requires_grad=true) const {
+    TensorBase::requires_grad_(_requires_grad);
+    return *this;
+  }
+};
+
+namespace detail {
+// Helper creator for the Tensor class which doesn't require the user to pass
+// in an intrusive_ptr; instead it just converts the arguments passed to the
+// requested intrusive_ptr type.
+template <typename T, typename... Args>
+Tensor make_tensor(Args&&... args) {
+  return Tensor(c10::make_intrusive<T>(std::forward<Args>(args)...));
+}
+
+} // namespace detail
+
+} // namespace at
+
+
+namespace at {
+
+// aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+inline void Tensor::__dispatch__backward(at::TensorList inputs, const ::std::optional<at::Tensor> & gradient, ::std::optional<bool> retain_graph, bool create_graph) const {
+    return at::_ops::_backward::call(const_cast<Tensor&>(*this), inputs, gradient, retain_graph, create_graph);
+}
+
+// aten::set_data(Tensor(a!) self, Tensor new_data) -> ()
+inline void Tensor::__dispatch_set_data(const at::Tensor & new_data) const {
+    return at::_ops::set_data::call(const_cast<Tensor&>(*this), new_data);
+}
+
+// aten::data(Tensor self) -> Tensor
+inline at::Tensor Tensor::__dispatch_data() const {
+    return at::_ops::data::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_leaf(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_leaf() const {
+    return at::_ops::is_leaf::call(const_cast<Tensor&>(*this));
+}
+
+// aten::output_nr(Tensor self) -> int
+inline int64_t Tensor::__dispatch_output_nr() const {
+    return at::_ops::output_nr::call(const_cast<Tensor&>(*this));
+}
+
+// aten::_version(Tensor self) -> int
+inline int64_t Tensor::__dispatch__version() const {
+    return at::_ops::_version::call(const_cast<Tensor&>(*this));
+}
+
+// aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)
+inline at::Tensor & Tensor::__dispatch_requires_grad_(bool requires_grad) const {
+    return at::_ops::requires_grad_::call(const_cast<Tensor&>(*this), requires_grad);
+}
+
+// aten::retain_grad(Tensor(a!) self) -> ()
+inline void Tensor::__dispatch_retain_grad() const {
+    return at::_ops::retain_grad::call(const_cast<Tensor&>(*this));
+}
+
+// aten::retains_grad(Tensor self) -> bool
+inline bool Tensor::__dispatch_retains_grad() const {
+    return at::_ops::retains_grad::call(const_cast<Tensor&>(*this));
+}
+
+// aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a)
+inline at::Tensor Tensor::_fw_primal(int64_t level) const {
+    return at::_ops::_fw_primal::call(const_cast<Tensor&>(*this), level);
+}
+
+// aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
+inline at::Tensor & Tensor::rename_(::std::optional<at::DimnameList> names) const {
+    return at::_ops::rename_::call(const_cast<Tensor&>(*this), names);
+}
+
+// aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
+inline at::Tensor Tensor::rename(::std::optional<at::DimnameList> names) const {
+    return at::_ops::rename::call(const_cast<Tensor&>(*this), names);
+}
+
+// aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)
+inline at::Tensor Tensor::align_to(at::DimnameList names) const {
+    return at::_ops::align_to::call(const_cast<Tensor&>(*this), names);
+}
+
+// aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)
+inline at::Tensor Tensor::align_to(at::DimnameList order, int64_t ellipsis_idx) const {
+    return at::_ops::align_to_ellipsis_idx::call(const_cast<Tensor&>(*this), order, ellipsis_idx);
+}
+
+// aten::align_as(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::align_as(const at::Tensor & other) const {
+    return at::_ops::align_as::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
+inline at::Tensor Tensor::refine_names(at::DimnameList names) const {
+    return at::_ops::refine_names::call(const_cast<Tensor&>(*this), names);
+}
+
+// aten::abs(Tensor self) -> Tensor
+inline at::Tensor Tensor::abs() const {
+    return at::_ops::abs::call(const_cast<Tensor&>(*this));
+}
+
+// aten::abs_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::abs_() const {
+    return at::_ops::abs_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::absolute(Tensor self) -> Tensor
+inline at::Tensor Tensor::absolute() const {
+    return at::_ops::absolute::call(const_cast<Tensor&>(*this));
+}
+
+// aten::absolute_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::absolute_() const {
+    return at::_ops::absolute_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::angle(Tensor self) -> Tensor
+inline at::Tensor Tensor::angle() const {
+    return at::_ops::angle::call(const_cast<Tensor&>(*this));
+}
+
+// aten::sgn(Tensor self) -> Tensor
+inline at::Tensor Tensor::sgn() const {
+    return at::_ops::sgn::call(const_cast<Tensor&>(*this));
+}
+
+// aten::sgn_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sgn_() const {
+    return at::_ops::sgn_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+inline at::Tensor Tensor::chalf(::std::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::chalf::call(const_cast<Tensor&>(*this), memory_format);
+}
+
+// aten::_conj(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::_conj() const {
+    return at::_ops::_conj::call(const_cast<Tensor&>(*this));
+}
+
+// aten::conj(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::__dispatch_conj() const {
+    return at::_ops::conj::call(const_cast<Tensor&>(*this));
+}
+
+// aten::_conj_physical(Tensor self) -> Tensor
+inline at::Tensor Tensor::_conj_physical() const {
+    return at::_ops::_conj_physical::call(const_cast<Tensor&>(*this));
+}
+
+// aten::conj_physical(Tensor self) -> Tensor
+inline at::Tensor Tensor::conj_physical() const {
+    return at::_ops::conj_physical::call(const_cast<Tensor&>(*this));
+}
+
+// aten::conj_physical_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::conj_physical_() const { + return at::_ops::conj_physical_::call(const_cast(*this)); +} + +// aten::resolve_conj(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::resolve_conj() const { + return at::_ops::resolve_conj::call(const_cast(*this)); +} + +// aten::resolve_neg(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::resolve_neg() const { + return at::_ops::resolve_neg::call(const_cast(*this)); +} + +// aten::_neg_view(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::_neg_view() const { + return at::_ops::_neg_view::call(const_cast(*this)); +} + +// aten::acos(Tensor self) -> Tensor +inline at::Tensor Tensor::acos() const { + return at::_ops::acos::call(const_cast(*this)); +} + +// aten::acos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::acos_() const { + return at::_ops::acos_::call(const_cast(*this)); +} + +// aten::arccos(Tensor self) -> Tensor +inline at::Tensor Tensor::arccos() const { + return at::_ops::arccos::call(const_cast(*this)); +} + +// aten::arccos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arccos_() const { + return at::_ops::arccos_::call(const_cast(*this)); +} + +// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::add_Tensor::call(const_cast(*this), other, alpha); +} + +// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::add_(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::add__Tensor::call(const_cast(*this), other, alpha); +} + +// aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::add(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::add_Scalar::call(const_cast(*this), other, alpha); +} + +// aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::add_(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::add__Scalar::call(const_cast(*this), other, alpha); +} + +// aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::addmv(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addmv::call(const_cast(*this), mat, vec, beta, alpha); +} + +// aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::addmv_(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addmv_::call(const_cast(*this), mat, vec, beta, alpha); +} + +// aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::addr(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addr::call(const_cast(*this), vec1, vec2, beta, alpha); +} + +// aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) 
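+// Illustrative usage sketch for addr / addr_ (rank-1 update: result = beta*self + alpha*(outer product of vec1 and vec2));
+// the tensors `m`, `v1`, `v2` below are hypothetical, not part of the generated schema:
+//   at::Tensor m  = at::zeros({3, 2});
+//   at::Tensor v1 = at::rand({3});
+//   at::Tensor v2 = at::rand({2});
+//   at::Tensor out = m.addr(v1, v2, /*beta=*/1, /*alpha=*/2);  // out-of-place
+//   m.addr_(v1, v2);                                           // in-place, returns a reference to m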
+inline at::Tensor & Tensor::addr_(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addr_::call(const_cast(*this), vec1, vec2, beta, alpha); +} + +// aten::_is_all_true(Tensor self) -> Tensor +inline at::Tensor Tensor::_is_all_true() const { + return at::_ops::_is_all_true::call(const_cast(*this)); +} + +// aten::_is_any_true(Tensor self) -> Tensor +inline at::Tensor Tensor::_is_any_true() const { + return at::_ops::_is_any_true::call(const_cast(*this)); +} + +// aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::all(int64_t dim, bool keepdim) const { + return at::_ops::all_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::all(at::OptionalIntArrayRef dim, bool keepdim) const { + return at::_ops::all_dims::call(const_cast(*this), dim, keepdim); +} + +// aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::all(at::Dimname dim, bool keepdim) const { + return at::_ops::all_dimname::call(const_cast(*this), dim, keepdim); +} + +// aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool +inline bool Tensor::allclose(const at::Tensor & other, double rtol, double atol, bool equal_nan) const { + return at::_ops::allclose::call(const_cast(*this), other, rtol, atol, equal_nan); +} + +// aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::any(int64_t dim, bool keepdim) const { + return at::_ops::any_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::any(at::OptionalIntArrayRef dim, bool keepdim) const { + return at::_ops::any_dims::call(const_cast(*this), dim, keepdim); +} + +// aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::any(at::Dimname dim, bool keepdim) const { + return at::_ops::any_dimname::call(const_cast(*this), dim, keepdim); +} + +// aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::argmax(::std::optional dim, bool keepdim) const { + return at::_ops::argmax::call(const_cast(*this), dim, keepdim); +} + +// aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::argmin(::std::optional dim, bool keepdim) const { + return at::_ops::argmin::call(const_cast(*this), dim, keepdim); +} + +// aten::acosh(Tensor self) -> Tensor +inline at::Tensor Tensor::acosh() const { + return at::_ops::acosh::call(const_cast(*this)); +} + +// aten::acosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::acosh_() const { + return at::_ops::acosh_::call(const_cast(*this)); +} + +// aten::arccosh(Tensor self) -> Tensor +inline at::Tensor Tensor::arccosh() const { + return at::_ops::arccosh::call(const_cast(*this)); +} + +// aten::arccosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arccosh_() const { + return at::_ops::arccosh_::call(const_cast(*this)); +} + +// aten::asinh(Tensor self) -> Tensor +inline at::Tensor Tensor::asinh() const { + return at::_ops::asinh::call(const_cast(*this)); +} + +// aten::asinh_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::asinh_() const { + return at::_ops::asinh_::call(const_cast(*this)); +} + +// aten::arcsinh(Tensor self) -> Tensor +inline at::Tensor Tensor::arcsinh() const { + return at::_ops::arcsinh::call(const_cast(*this)); +} + +// aten::arcsinh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arcsinh_() const { + return at::_ops::arcsinh_::call(const_cast(*this)); +} + +// aten::atanh(Tensor self) -> Tensor +inline at::Tensor Tensor::atanh() const { + return at::_ops::atanh::call(const_cast(*this)); +} + +// aten::atanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::atanh_() const { + return at::_ops::atanh_::call(const_cast(*this)); +} + +// aten::arctanh(Tensor self) -> Tensor +inline at::Tensor Tensor::arctanh() const { + return at::_ops::arctanh::call(const_cast(*this)); +} + +// aten::arctanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arctanh_() const { + return at::_ops::arctanh_::call(const_cast(*this)); +} + +// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) +inline at::Tensor Tensor::as_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} + +// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) +inline at::Tensor Tensor::as_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided::call(const_cast(*this), size, stride, storage_offset); +} + +// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) +inline const at::Tensor & Tensor::as_strided_(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} + +// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) +inline const at::Tensor & Tensor::as_strided__symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_::call(const_cast(*this), size, stride, storage_offset); +} + +// aten::asin(Tensor self) -> Tensor +inline at::Tensor Tensor::asin() const { + return at::_ops::asin::call(const_cast(*this)); +} + +// aten::asin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::asin_() const { + return at::_ops::asin_::call(const_cast(*this)); +} + +// aten::arcsin(Tensor self) -> Tensor +inline at::Tensor Tensor::arcsin() const { + return at::_ops::arcsin::call(const_cast(*this)); +} + +// aten::arcsin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arcsin_() const { + return at::_ops::arcsin_::call(const_cast(*this)); +} + +// aten::atan(Tensor self) -> Tensor +inline at::Tensor Tensor::atan() const { + return at::_ops::atan::call(const_cast(*this)); +} + +// aten::atan_(Tensor(a!) self) -> Tensor(a!) 
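+// Illustrative usage sketch for atan / atan_ (element-wise arctangent, in radians);
+// the tensor `t` below is hypothetical:
+//   at::Tensor t = at::rand({4}) * 2.0 - 1.0;  // values in [-1, 1)
+//   at::Tensor y = t.atan();                   // out-of-place
+//   t.atan_();                                 // in-place, returns a reference to t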
+inline at::Tensor & Tensor::atan_() const { + return at::_ops::atan_::call(const_cast(*this)); +} + +// aten::arctan(Tensor self) -> Tensor +inline at::Tensor Tensor::arctan() const { + return at::_ops::arctan::call(const_cast(*this)); +} + +// aten::arctan_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arctan_() const { + return at::_ops::arctan_::call(const_cast(*this)); +} + +// aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::baddbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::baddbmm::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::baddbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::baddbmm_::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor +inline at::Tensor Tensor::bernoulli(::std::optional generator) const { + return at::_ops::bernoulli::call(const_cast(*this), generator); +} + +// aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::bernoulli_(const at::Tensor & p, ::std::optional generator) const { + return at::_ops::bernoulli__Tensor::call(const_cast(*this), p, generator); +} + +// aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::bernoulli_(double p, ::std::optional generator) const { + return at::_ops::bernoulli__float::call(const_cast(*this), p, generator); +} + +// aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor +inline at::Tensor Tensor::bernoulli(double p, ::std::optional generator) const { + return at::_ops::bernoulli_p::call(const_cast(*this), p, generator); +} + +// aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor +inline at::Tensor Tensor::bincount(const ::std::optional & weights, int64_t minlength) const { + return at::_ops::bincount::call(const_cast(*this), weights, minlength); +} + +// aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor +inline at::Tensor Tensor::bincount_symint(const ::std::optional & weights, c10::SymInt minlength) const { + return at::_ops::bincount::call(const_cast(*this), weights, minlength); +} + +// aten::bitwise_not(Tensor self) -> Tensor +inline at::Tensor Tensor::bitwise_not() const { + return at::_ops::bitwise_not::call(const_cast(*this)); +} + +// aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_not_() const { + return at::_ops::bitwise_not_::call(const_cast(*this)); +} + +// aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::copysign(const at::Tensor & other) const { + return at::_ops::copysign_Tensor::call(const_cast(*this), other); +} + +// aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::copysign_(const at::Tensor & other) const { + return at::_ops::copysign__Tensor::call(const_cast(*this), other); +} + +// aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::copysign(const at::Scalar & other) const { + return at::_ops::copysign_Scalar::call(const_cast(*this), other); +} + +// aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::copysign_(const at::Scalar & other) const { + return at::_ops::copysign__Scalar::call(const_cast(*this), other); +} + +// aten::_lazy_clone(Tensor self) -> Tensor +inline at::Tensor Tensor::_lazy_clone() const { + return at::_ops::_lazy_clone::call(const_cast(*this)); +} + +// aten::logical_not(Tensor self) -> Tensor +inline at::Tensor Tensor::logical_not() const { + return at::_ops::logical_not::call(const_cast(*this)); +} + +// aten::logical_not_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::logical_not_() const { + return at::_ops::logical_not_::call(const_cast(*this)); +} + +// aten::logical_xor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logical_xor(const at::Tensor & other) const { + return at::_ops::logical_xor::call(const_cast(*this), other); +} + +// aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::logical_xor_(const at::Tensor & other) const { + return at::_ops::logical_xor_::call(const_cast(*this), other); +} + +// aten::logical_and(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logical_and(const at::Tensor & other) const { + return at::_ops::logical_and::call(const_cast(*this), other); +} + +// aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::logical_and_(const at::Tensor & other) const { + return at::_ops::logical_and_::call(const_cast(*this), other); +} + +// aten::logical_or(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logical_or(const at::Tensor & other) const { + return at::_ops::logical_or::call(const_cast(*this), other); +} + +// aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::logical_or_(const at::Tensor & other) const { + return at::_ops::logical_or_::call(const_cast(*this), other); +} + +// aten::bmm(Tensor self, Tensor mat2) -> Tensor +inline at::Tensor Tensor::bmm(const at::Tensor & mat2) const { + return at::_ops::bmm::call(const_cast(*this), mat2); +} + +// aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::broadcast_to(at::IntArrayRef size) const { + return at::_ops::broadcast_to::call(const_cast(*this), c10::fromIntArrayRefSlow(size)); +} + +// aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::broadcast_to_symint(c10::SymIntArrayRef size) const { + return at::_ops::broadcast_to::call(const_cast(*this), size); +} + +// aten::ceil(Tensor self) -> Tensor +inline at::Tensor Tensor::ceil() const { + return at::_ops::ceil::call(const_cast(*this)); +} + +// aten::ceil_(Tensor(a!) self) -> Tensor(a!) 
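+// Illustrative usage sketch for ceil / ceil_ (element-wise rounding up to the nearest integer);
+// the tensor `t` below is hypothetical:
+//   at::Tensor t = at::rand({3}) * 10.0;
+//   at::Tensor c = t.ceil();   // out-of-place
+//   t.ceil_();                 // in-place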
+inline at::Tensor & Tensor::ceil_() const { + return at::_ops::ceil_::call(const_cast(*this)); +} + +// aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_chunk(int64_t chunks, int64_t dim) const { + return at::_ops::unsafe_chunk::call(const_cast(*this), chunks, dim); +} + +// aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { + return at::_ops::chunk::call(const_cast(*this), chunks, dim); +} + +// aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split(int64_t sections, int64_t dim) const { + return at::_ops::tensor_split_sections::call(const_cast(*this), sections, dim); +} + +// aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split_symint(c10::SymInt sections, int64_t dim) const { + return at::_ops::tensor_split_sections::call(const_cast(*this), sections, dim); +} + +// aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split(at::IntArrayRef indices, int64_t dim) const { + return at::_ops::tensor_split_indices::call(const_cast(*this), c10::fromIntArrayRefSlow(indices), dim); +} + +// aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split_symint(c10::SymIntArrayRef indices, int64_t dim) const { + return at::_ops::tensor_split_indices::call(const_cast(*this), indices, dim); +} + +// aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split(const at::Tensor & tensor_indices_or_sections, int64_t dim) const { + return at::_ops::tensor_split_tensor_indices_or_sections::call(const_cast(*this), tensor_indices_or_sections, dim); +} + +// aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor +inline at::Tensor Tensor::clamp(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp::call(const_cast(*this), min, max); +} + +// aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor +inline at::Tensor Tensor::clamp(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp_Tensor::call(const_cast(*this), min, max); +} + +// aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp_::call(const_cast(*this), min, max); +} + +// aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp__Tensor::call(const_cast(*this), min, max); +} + +// aten::clamp_max(Tensor self, Scalar max) -> Tensor +inline at::Tensor Tensor::clamp_max(const at::Scalar & max) const { + return at::_ops::clamp_max::call(const_cast(*this), max); +} + +// aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor +inline at::Tensor Tensor::clamp_max(const at::Tensor & max) const { + return at::_ops::clamp_max_Tensor::call(const_cast(*this), max); +} + +// aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) 
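+// Illustrative usage sketch for the clamp family (two-sided clamp, or one-sided via clamp_max / clamp_min);
+// the tensor `t` below is hypothetical:
+//   at::Tensor t = at::randn({5});
+//   at::Tensor a = t.clamp(-1.0, 1.0);  // clamp into [-1, 1]
+//   t.clamp_max_(0.5);                  // in-place upper bound only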
+inline at::Tensor & Tensor::clamp_max_(const at::Scalar & max) const { + return at::_ops::clamp_max_::call(const_cast(*this), max); +} + +// aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_max_(const at::Tensor & max) const { + return at::_ops::clamp_max__Tensor::call(const_cast(*this), max); +} + +// aten::clamp_min(Tensor self, Scalar min) -> Tensor +inline at::Tensor Tensor::clamp_min(const at::Scalar & min) const { + return at::_ops::clamp_min::call(const_cast(*this), min); +} + +// aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor +inline at::Tensor Tensor::clamp_min(const at::Tensor & min) const { + return at::_ops::clamp_min_Tensor::call(const_cast(*this), min); +} + +// aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_min_(const at::Scalar & min) const { + return at::_ops::clamp_min_::call(const_cast(*this), min); +} + +// aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_min_(const at::Tensor & min) const { + return at::_ops::clamp_min__Tensor::call(const_cast(*this), min); +} + +// aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor +inline at::Tensor Tensor::clip(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip::call(const_cast(*this), min, max); +} + +// aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor +inline at::Tensor Tensor::clip(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip_Tensor::call(const_cast(*this), min, max); +} + +// aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) +inline at::Tensor & Tensor::clip_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip_::call(const_cast(*this), min, max); +} + +// aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) +inline at::Tensor & Tensor::clip_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip__Tensor::call(const_cast(*this), min, max); +} + +// aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) +inline at::Tensor Tensor::__dispatch_contiguous(at::MemoryFormat memory_format) const { + return at::_ops::contiguous::call(const_cast(*this), memory_format); +} + +// aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) +inline at::Tensor & Tensor::copy_(const at::Tensor & src, bool non_blocking) const { + return at::_ops::copy_::call(const_cast(*this), src, non_blocking); +} + +// aten::cos(Tensor self) -> Tensor +inline at::Tensor Tensor::cos() const { + return at::_ops::cos::call(const_cast(*this)); +} + +// aten::cos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::cos_() const { + return at::_ops::cos_::call(const_cast(*this)); +} + +// aten::cosh(Tensor self) -> Tensor +inline at::Tensor Tensor::cosh() const { + return at::_ops::cosh::call(const_cast(*this)); +} + +// aten::cosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::cosh_() const { + return at::_ops::cosh_::call(const_cast(*this)); +} + +// aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor +inline at::Tensor Tensor::count_nonzero(at::IntArrayRef dim) const { + return at::_ops::count_nonzero_dim_IntList::call(const_cast(*this), dim); +} + +// aten::count_nonzero(Tensor self, int? 
dim=None) -> Tensor +inline at::Tensor Tensor::count_nonzero(::std::optional dim) const { + return at::_ops::count_nonzero::call(const_cast(*this), dim); +} + +// aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor +inline at::Tensor Tensor::cov(int64_t correction, const ::std::optional & fweights, const ::std::optional & aweights) const { + return at::_ops::cov::call(const_cast(*this), correction, fweights, aweights); +} + +// aten::corrcoef(Tensor self) -> Tensor +inline at::Tensor Tensor::corrcoef() const { + return at::_ops::corrcoef::call(const_cast(*this)); +} + +// aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummax(int64_t dim) const { + return at::_ops::cummax::call(const_cast(*this), dim); +} + +// aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummax(at::Dimname dim) const { + return at::_ops::cummax_dimname::call(const_cast(*this), dim); +} + +// aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummin(int64_t dim) const { + return at::_ops::cummin::call(const_cast(*this), dim); +} + +// aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummin(at::Dimname dim) const { + return at::_ops::cummin_dimname::call(const_cast(*this), dim); +} + +// aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::cumprod(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumprod::call(const_cast(*this), dim, dtype); +} + +// aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) +inline at::Tensor & Tensor::cumprod_(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumprod_::call(const_cast(*this), dim, dtype); +} + +// aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::cumprod(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumprod_dimname::call(const_cast(*this), dim, dtype); +} + +// aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) +inline at::Tensor & Tensor::cumprod_(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumprod__dimname::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::cumsum(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumsum::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) +inline at::Tensor & Tensor::cumsum_(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumsum_::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::cumsum(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumsum_dimname::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) 
+inline at::Tensor & Tensor::cumsum_(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumsum__dimname::call(const_cast(*this), dim, dtype); +} + +// aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor +inline at::Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) const { + return at::_ops::diag_embed::call(const_cast(*this), offset, dim1, dim2); +} + +// aten::diagflat(Tensor self, int offset=0) -> Tensor +inline at::Tensor Tensor::diagflat(int64_t offset) const { + return at::_ops::diagflat::call(const_cast(*this), offset); +} + +// aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) +inline at::Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const { + return at::_ops::diagonal::call(const_cast(*this), offset, dim1, dim2); +} + +// aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a) +inline at::Tensor Tensor::diagonal(at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset) const { + return at::_ops::diagonal_Dimname::call(const_cast(*this), outdim, dim1, dim2, offset); +} + +// aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!) +inline at::Tensor & Tensor::fill_diagonal_(const at::Scalar & fill_value, bool wrap) const { + return at::_ops::fill_diagonal_::call(const_cast(*this), fill_value, wrap); +} + +// aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor +inline at::Tensor Tensor::diff(int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append) const { + return at::_ops::diff::call(const_cast(*this), n, dim, prepend, append); +} + +// aten::div.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::div(const at::Tensor & other) const { + return at::_ops::div_Tensor::call(const_cast(*this), other); +} + +// aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::div_(const at::Tensor & other) const { + return at::_ops::div__Tensor::call(const_cast(*this), other); +} + +// aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::div(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::div_Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) +inline at::Tensor & Tensor::div_(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::div__Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::div.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::div(const at::Scalar & other) const { + return at::_ops::div_Scalar::call(const_cast(*this), other); +} + +// aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::div_(const at::Scalar & other) const { + return at::_ops::div__Scalar::call(const_cast(*this), other); +} + +// aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::div(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::div_Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) 
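+// Illustrative usage sketch for div / div_ with a rounding mode ("floor" or "trunc"; passing nullopt keeps true division);
+// the tensor `t` below is hypothetical:
+//   at::Tensor t = at::rand({4}) * 10.0;
+//   at::Tensor q = t.div(3.0, "floor");   // element-wise floor division
+//   t.div_(3.0, "trunc");                 // in-place, rounds the quotient toward zero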
+inline at::Tensor & Tensor::div_(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::div__Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::divide(const at::Tensor & other) const { + return at::_ops::divide_Tensor::call(const_cast(*this), other); +} + +// aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::divide_(const at::Tensor & other) const { + return at::_ops::divide__Tensor::call(const_cast(*this), other); +} + +// aten::divide.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::divide(const at::Scalar & other) const { + return at::_ops::divide_Scalar::call(const_cast(*this), other); +} + +// aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::divide_(const at::Scalar & other) const { + return at::_ops::divide__Scalar::call(const_cast(*this), other); +} + +// aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::divide(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::divide_Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) +inline at::Tensor & Tensor::divide_(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::divide__Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::divide(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::divide_Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) +inline at::Tensor & Tensor::divide_(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::divide__Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::true_divide(const at::Tensor & other) const { + return at::_ops::true_divide_Tensor::call(const_cast(*this), other); +} + +// aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::true_divide_(const at::Tensor & other) const { + return at::_ops::true_divide__Tensor::call(const_cast(*this), other); +} + +// aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::true_divide(const at::Scalar & other) const { + return at::_ops::true_divide_Scalar::call(const_cast(*this), other); +} + +// aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::true_divide_(const at::Scalar & other) const { + return at::_ops::true_divide__Scalar::call(const_cast(*this), other); +} + +// aten::dot(Tensor self, Tensor tensor) -> Tensor +inline at::Tensor Tensor::dot(const at::Tensor & tensor) const { + return at::_ops::dot::call(const_cast(*this), tensor); +} + +// aten::vdot(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::vdot(const at::Tensor & other) const { + return at::_ops::vdot::call(const_cast(*this), other); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty(at::IntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_empty::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty::call(const_cast(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_symint(c10::SymIntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_empty::call(const_cast(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty::call(const_cast(*this), size, dtype, layout, device, pin_memory); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options) const { + return at::_ops::new_empty_strided::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty_strided::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options) const { + return at::_ops::new_empty_strided::call(const_cast(*this), size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty_strided::call(const_cast(*this), size, stride, dtype, layout, device, pin_memory); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full(at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options) const { + return at::_ops::new_full::call(const_cast(*this), c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_full::call(const_cast(*this), c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options) const { + return at::_ops::new_full::call(const_cast(*this), size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_full::call(const_cast(*this), size, fill_value, dtype, layout, device, pin_memory); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros(at::IntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_zeros::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_zeros::call(const_cast(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros_symint(c10::SymIntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_zeros::call(const_cast(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_zeros::call(const_cast(*this), size, dtype, layout, device, pin_memory); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones(at::IntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_ones::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_ones::call(const_cast(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones_symint(c10::SymIntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_ones::call(const_cast(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_ones::call(const_cast(*this), size, dtype, layout, device, pin_memory); +} + +// aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) +inline const at::Tensor & Tensor::resize_(at::IntArrayRef size, ::std::optional memory_format) const { + return at::_ops::resize_::call(const_cast(*this), c10::fromIntArrayRefSlow(size), memory_format); +} + +// aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) +inline const at::Tensor & Tensor::resize__symint(c10::SymIntArrayRef size, ::std::optional memory_format) const { + return at::_ops::resize_::call(const_cast(*this), size, memory_format); +} + +// aten::erf(Tensor self) -> Tensor +inline at::Tensor Tensor::erf() const { + return at::_ops::erf::call(const_cast(*this)); +} + +// aten::erf_(Tensor(a!) self) -> Tensor(a!) 
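+// Illustrative usage sketch for erf / erf_ (element-wise Gauss error function);
+// the tensor `t` below is hypothetical:
+//   at::Tensor t = at::randn({4});
+//   at::Tensor e = t.erf();   // out-of-place
+//   t.erf_();                 // in-place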
+inline at::Tensor & Tensor::erf_() const { + return at::_ops::erf_::call(const_cast(*this)); +} + +// aten::erfc(Tensor self) -> Tensor +inline at::Tensor Tensor::erfc() const { + return at::_ops::erfc::call(const_cast(*this)); +} + +// aten::erfc_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::erfc_() const { + return at::_ops::erfc_::call(const_cast(*this)); +} + +// aten::exp(Tensor self) -> Tensor +inline at::Tensor Tensor::exp() const { + return at::_ops::exp::call(const_cast(*this)); +} + +// aten::exp_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::exp_() const { + return at::_ops::exp_::call(const_cast(*this)); +} + +// aten::exp2(Tensor self) -> Tensor +inline at::Tensor Tensor::exp2() const { + return at::_ops::exp2::call(const_cast(*this)); +} + +// aten::exp2_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::exp2_() const { + return at::_ops::exp2_::call(const_cast(*this)); +} + +// aten::expm1(Tensor self) -> Tensor +inline at::Tensor Tensor::expm1() const { + return at::_ops::expm1::call(const_cast(*this)); +} + +// aten::expm1_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::expm1_() const { + return at::_ops::expm1_::call(const_cast(*this)); +} + +// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) +inline at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit) const { + return at::_ops::expand::call(const_cast(*this), c10::fromIntArrayRefSlow(size), implicit); +} + +// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) +inline at::Tensor Tensor::expand_symint(c10::SymIntArrayRef size, bool implicit) const { + return at::_ops::expand::call(const_cast(*this), size, implicit); +} + +// aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a) +inline at::Tensor Tensor::expand_as(const at::Tensor & other) const { + return at::_ops::expand_as::call(const_cast(*this), other); +} + +// aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) +inline at::Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { + return at::_ops::flatten_using_ints::call(const_cast(*this), start_dim, end_dim); +} + +// aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a) +inline at::Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, at::Dimname out_dim) const { + return at::_ops::flatten_named_out_dim::call(const_cast(*this), start_dim, end_dim, out_dim); +} + +// aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a) +inline at::Tensor Tensor::flatten(at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) const { + return at::_ops::flatten_using_names::call(const_cast(*this), start_dim, end_dim, out_dim); +} + +// aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) +inline at::Tensor Tensor::flatten(at::DimnameList dims, at::Dimname out_dim) const { + return at::_ops::flatten_DimnameList::call(const_cast(*this), dims, out_dim); +} + +// aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) +inline at::Tensor Tensor::unflatten(int64_t dim, at::IntArrayRef sizes) const { + return at::_ops::unflatten_int::call(const_cast(*this), dim, c10::fromIntArrayRefSlow(sizes)); +} + +// aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) +inline at::Tensor Tensor::unflatten_symint(int64_t dim, c10::SymIntArrayRef sizes) const { + return 
at::_ops::unflatten_int::call(const_cast(*this), dim, sizes); +} + +// aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) +inline at::Tensor Tensor::unflatten(at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) const { + return at::_ops::unflatten_Dimname::call(const_cast(*this), dim, c10::fromIntArrayRefSlow(sizes), names); +} + +// aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) +inline at::Tensor Tensor::unflatten_symint(at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) const { + return at::_ops::unflatten_Dimname::call(const_cast(*this), dim, sizes, names); +} + +// aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::fill_(const at::Scalar & value) const { + return at::_ops::fill__Scalar::call(const_cast(*this), value); +} + +// aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) +inline at::Tensor & Tensor::fill_(const at::Tensor & value) const { + return at::_ops::fill__Tensor::call(const_cast(*this), value); +} + +// aten::floor(Tensor self) -> Tensor +inline at::Tensor Tensor::floor() const { + return at::_ops::floor::call(const_cast(*this)); +} + +// aten::floor_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::floor_() const { + return at::_ops::floor_::call(const_cast(*this)); +} + +// aten::floor_divide(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::floor_divide(const at::Tensor & other) const { + return at::_ops::floor_divide::call(const_cast(*this), other); +} + +// aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::floor_divide_(const at::Tensor & other) const { + return at::_ops::floor_divide__Tensor::call(const_cast(*this), other); +} + +// aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::floor_divide(const at::Scalar & other) const { + return at::_ops::floor_divide_Scalar::call(const_cast(*this), other); +} + +// aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::floor_divide_(const at::Scalar & other) const { + return at::_ops::floor_divide__Scalar::call(const_cast(*this), other); +} + +// aten::frac(Tensor self) -> Tensor +inline at::Tensor Tensor::frac() const { + return at::_ops::frac::call(const_cast(*this)); +} + +// aten::frac_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::frac_() const { + return at::_ops::frac_::call(const_cast(*this)); +} + +// aten::gcd(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::gcd(const at::Tensor & other) const { + return at::_ops::gcd::call(const_cast(*this), other); +} + +// aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::gcd_(const at::Tensor & other) const { + return at::_ops::gcd_::call(const_cast(*this), other); +} + +// aten::lcm(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::lcm(const at::Tensor & other) const { + return at::_ops::lcm::call(const_cast(*this), other); +} + +// aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
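+// Illustrative usage sketch for gcd / lcm and their in-place forms (element-wise, integral tensors);
+// the tensors `a` and `b` below are hypothetical:
+//   at::Tensor a = at::full({3}, 6, at::kLong);
+//   at::Tensor b = at::full({3}, 4, at::kLong);
+//   at::Tensor l = a.lcm(b);   // {12, 12, 12}
+//   a.lcm_(b);                 // in-place variant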
+inline at::Tensor & Tensor::lcm_(const at::Tensor & other) const { + return at::_ops::lcm_::call(const_cast(*this), other); +} + +// aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor +inline at::Tensor Tensor::index(const c10::List<::std::optional> & indices) const { + return at::_ops::index_Tensor::call(const_cast(*this), indices); +} + +// aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) +inline at::Tensor & Tensor::index_copy_(int64_t dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy_::call(const_cast(*this), dim, index, source); +} + +// aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor +inline at::Tensor Tensor::index_copy(int64_t dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy::call(const_cast(*this), dim, index, source); +} + +// aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!) +inline at::Tensor & Tensor::index_copy_(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy__dimname::call(const_cast(*this), dim, index, source); +} + +// aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor +inline at::Tensor Tensor::index_copy(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy_dimname::call(const_cast(*this), dim, index, source); +} + +// aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) +inline at::Tensor & Tensor::index_put_(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) const { + return at::_ops::index_put_::call(const_cast(*this), indices, values, accumulate); +} + +// aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor +inline at::Tensor Tensor::index_put(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) const { + return at::_ops::index_put::call(const_cast(*this), indices, values, accumulate); +} + +// aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor +inline at::Tensor Tensor::isclose(const at::Tensor & other, double rtol, double atol, bool equal_nan) const { + return at::_ops::isclose::call(const_cast(*this), other, rtol, atol, equal_nan); +} + +// aten::isnan(Tensor self) -> Tensor +inline at::Tensor Tensor::isnan() const { + return at::_ops::isnan::call(const_cast(*this)); +} + +// aten::is_distributed(Tensor self) -> bool +inline bool Tensor::is_distributed() const { + return at::_ops::is_distributed::call(const_cast(*this)); +} + +// aten::is_floating_point(Tensor self) -> bool +inline bool Tensor::__dispatch_is_floating_point() const { + return at::_ops::is_floating_point::call(const_cast(*this)); +} + +// aten::is_complex(Tensor self) -> bool +inline bool Tensor::__dispatch_is_complex() const { + return at::_ops::is_complex::call(const_cast(*this)); +} + +// aten::is_conj(Tensor self) -> bool +inline bool Tensor::__dispatch_is_conj() const { + return at::_ops::is_conj::call(const_cast(*this)); +} + +// aten::_is_zerotensor(Tensor self) -> bool +inline bool Tensor::__dispatch__is_zerotensor() const { + return at::_ops::_is_zerotensor::call(const_cast(*this)); +} + +// aten::is_neg(Tensor self) -> bool +inline bool Tensor::__dispatch_is_neg() const { + return 
at::_ops::is_neg::call(const_cast(*this)); +} + +// aten::isreal(Tensor self) -> Tensor +inline at::Tensor Tensor::isreal() const { + return at::_ops::isreal::call(const_cast(*this)); +} + +// aten::is_nonzero(Tensor self) -> bool +inline bool Tensor::is_nonzero() const { + return at::_ops::is_nonzero::call(const_cast(*this)); +} + +// aten::is_same_size(Tensor self, Tensor other) -> bool +inline bool Tensor::is_same_size(const at::Tensor & other) const { + return at::_ops::is_same_size::call(const_cast(*this), other); +} + +// aten::is_signed(Tensor self) -> bool +inline bool Tensor::__dispatch_is_signed() const { + return at::_ops::is_signed::call(const_cast(*this)); +} + +// aten::is_inference(Tensor self) -> bool +inline bool Tensor::__dispatch_is_inference() const { + return at::_ops::is_inference::call(const_cast(*this)); +} + +// aten::kron(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::kron(const at::Tensor & other) const { + return at::_ops::kron::call(const_cast(*this), other); +} + +// aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool keepdim) const { + return at::_ops::kthvalue::call(const_cast(*this), k, dim, keepdim); +} + +// aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue_symint(c10::SymInt k, int64_t dim, bool keepdim) const { + return at::_ops::kthvalue::call(const_cast(*this), k, dim, keepdim); +} + +// aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue(int64_t k, at::Dimname dim, bool keepdim) const { + return at::_ops::kthvalue_dimname::call(const_cast(*this), k, dim, keepdim); +} + +// aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue_symint(c10::SymInt k, at::Dimname dim, bool keepdim) const { + return at::_ops::kthvalue_dimname::call(const_cast(*this), k, dim, keepdim); +} + +// aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor +inline at::Tensor Tensor::nan_to_num(::std::optional nan, ::std::optional posinf, ::std::optional neginf) const { + return at::_ops::nan_to_num::call(const_cast(*this), nan, posinf, neginf); +} + +// aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) +inline at::Tensor & Tensor::nan_to_num_(::std::optional nan, ::std::optional posinf, ::std::optional neginf) const { + return at::_ops::nan_to_num_::call(const_cast(*this), nan, posinf, neginf); +} + +// aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::ldexp(const at::Tensor & other) const { + return at::_ops::ldexp_Tensor::call(const_cast(*this), other); +} + +// aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::ldexp_(const at::Tensor & other) const { + return at::_ops::ldexp_::call(const_cast(*this), other); +} + +// aten::log(Tensor self) -> Tensor +inline at::Tensor Tensor::log() const { + return at::_ops::log::call(const_cast(*this)); +} + +// aten::log_(Tensor(a!) self) -> Tensor(a!) 
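+// Illustrative usage sketch for the log family (log, log10, log1p, log2 and their in-place forms);
+// the tensor `p` below is hypothetical and kept strictly positive:
+//   at::Tensor p = at::rand({4}) + 0.5;
+//   at::Tensor n = p.log();    // natural logarithm, out-of-place
+//   p.log_();                  // in-place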
+inline at::Tensor & Tensor::log_() const { + return at::_ops::log_::call(const_cast(*this)); +} + +// aten::log10(Tensor self) -> Tensor +inline at::Tensor Tensor::log10() const { + return at::_ops::log10::call(const_cast(*this)); +} + +// aten::log10_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::log10_() const { + return at::_ops::log10_::call(const_cast(*this)); +} + +// aten::log1p(Tensor self) -> Tensor +inline at::Tensor Tensor::log1p() const { + return at::_ops::log1p::call(const_cast(*this)); +} + +// aten::log1p_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::log1p_() const { + return at::_ops::log1p_::call(const_cast(*this)); +} + +// aten::log2(Tensor self) -> Tensor +inline at::Tensor Tensor::log2() const { + return at::_ops::log2::call(const_cast(*this)); +} + +// aten::log2_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::log2_() const { + return at::_ops::log2_::call(const_cast(*this)); +} + +// aten::logaddexp(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logaddexp(const at::Tensor & other) const { + return at::_ops::logaddexp::call(const_cast(*this), other); +} + +// aten::logaddexp2(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logaddexp2(const at::Tensor & other) const { + return at::_ops::logaddexp2::call(const_cast(*this), other); +} + +// aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::xlogy(const at::Tensor & other) const { + return at::_ops::xlogy_Tensor::call(const_cast(*this), other); +} + +// aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::xlogy(const at::Scalar & other) const { + return at::_ops::xlogy_Scalar_Other::call(const_cast(*this), other); +} + +// aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::xlogy_(const at::Tensor & other) const { + return at::_ops::xlogy__Tensor::call(const_cast(*this), other); +} + +// aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::xlogy_(const at::Scalar & other) const { + return at::_ops::xlogy__Scalar_Other::call(const_cast(*this), other); +} + +// aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::log_softmax(int64_t dim, ::std::optional dtype) const { + return at::_ops::log_softmax_int::call(const_cast(*this), dim, dtype); +} + +// aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor +inline at::Tensor Tensor::log_softmax(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::log_softmax_Dimname::call(const_cast(*this), dim, dtype); +} + +// aten::logcumsumexp(Tensor self, int dim) -> Tensor +inline at::Tensor Tensor::logcumsumexp(int64_t dim) const { + return at::_ops::logcumsumexp::call(const_cast(*this), dim); +} + +// aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor +inline at::Tensor Tensor::logcumsumexp(at::Dimname dim) const { + return at::_ops::logcumsumexp_dimname::call(const_cast(*this), dim); +} + +// aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::logsumexp(at::IntArrayRef dim, bool keepdim) const { + return at::_ops::logsumexp::call(const_cast(*this), dim, keepdim); +} + +// aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::logsumexp(at::DimnameList dim, bool keepdim) const { + return at::_ops::logsumexp_names::call(const_cast(*this), dim, keepdim); +} + +// aten::matmul(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::matmul(const at::Tensor & other) const { + return at::_ops::matmul::call(const_cast(*this), other); +} + +// aten::matrix_power(Tensor self, int n) -> Tensor +inline at::Tensor Tensor::matrix_power(int64_t n) const { + return at::_ops::matrix_power::call(const_cast(*this), n); +} + +// aten::matrix_exp(Tensor self) -> Tensor +inline at::Tensor Tensor::matrix_exp() const { + return at::_ops::matrix_exp::call(const_cast(*this)); +} + +// aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) +inline ::std::tuple Tensor::aminmax(::std::optional dim, bool keepdim) const { + return at::_ops::aminmax::call(const_cast(*this), dim, keepdim); +} + +// aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::max(int64_t dim, bool keepdim) const { + return at::_ops::max_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::max(at::Dimname dim, bool keepdim) const { + return at::_ops::max_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor +inline at::Tensor Tensor::amax(at::IntArrayRef dim, bool keepdim) const { + return at::_ops::amax::call(const_cast(*this), dim, keepdim); +} + +// aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::mean(::std::optional dtype) const { + return at::_ops::mean::call(const_cast(*this), dtype); +} + +// aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::mean(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::mean_dim::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::mean(at::DimnameList dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::mean_names_dim::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor +inline at::Tensor Tensor::nanmean(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::nanmean::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::median(Tensor self) -> Tensor +inline at::Tensor Tensor::median() const { + return at::_ops::median::call(const_cast(*this)); +} + +// aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::median(int64_t dim, bool keepdim) const { + return at::_ops::median_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::median(at::Dimname dim, bool keepdim) const { + return at::_ops::median_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::nanmedian(Tensor self) -> Tensor +inline at::Tensor Tensor::nanmedian() const { + return at::_ops::nanmedian::call(const_cast(*this)); +} + +// aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::nanmedian(int64_t dim, bool keepdim) const { + return at::_ops::nanmedian_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::nanmedian(at::Dimname dim, bool keepdim) const { + return at::_ops::nanmedian_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::min(int64_t dim, bool keepdim) const { + return at::_ops::min_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::min(at::Dimname dim, bool keepdim) const { + return at::_ops::min_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor +inline at::Tensor Tensor::amin(at::IntArrayRef dim, bool keepdim) const { + return at::_ops::amin::call(const_cast(*this), dim, keepdim); +} + +// aten::mm(Tensor self, Tensor mat2) -> Tensor +inline at::Tensor Tensor::mm(const at::Tensor & mat2) const { + return at::_ops::mm::call(const_cast(*this), mat2); +} + +// aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::mode(int64_t dim, bool keepdim) const { + return at::_ops::mode::call(const_cast(*this), dim, keepdim); +} + +// aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::mode(at::Dimname dim, bool keepdim) const { + return at::_ops::mode_dimname::call(const_cast(*this), dim, keepdim); +} + +// aten::mul.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::mul(const at::Tensor & other) const { + return at::_ops::mul_Tensor::call(const_cast(*this), other); +} + +// aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::mul_(const at::Tensor & other) const { + return at::_ops::mul__Tensor::call(const_cast(*this), other); +} + +// aten::mul.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::mul(const at::Scalar & other) const { + return at::_ops::mul_Scalar::call(const_cast(*this), other); +} + +// aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+inline at::Tensor & Tensor::mul_(const at::Scalar & other) const { + return at::_ops::mul__Scalar::call(const_cast(*this), other); +} + +// aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::multiply(const at::Tensor & other) const { + return at::_ops::multiply_Tensor::call(const_cast(*this), other); +} + +// aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::multiply_(const at::Tensor & other) const { + return at::_ops::multiply__Tensor::call(const_cast(*this), other); +} + +// aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::multiply(const at::Scalar & other) const { + return at::_ops::multiply_Scalar::call(const_cast(*this), other); +} + +// aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::multiply_(const at::Scalar & other) const { + return at::_ops::multiply__Scalar::call(const_cast(*this), other); +} + +// aten::mv(Tensor self, Tensor vec) -> Tensor +inline at::Tensor Tensor::mv(const at::Tensor & vec) const { + return at::_ops::mv::call(const_cast(*this), vec); +} + +// aten::mvlgamma(Tensor self, int p) -> Tensor +inline at::Tensor Tensor::mvlgamma(int64_t p) const { + return at::_ops::mvlgamma::call(const_cast(*this), p); +} + +// aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) +inline at::Tensor & Tensor::mvlgamma_(int64_t p) const { + return at::_ops::mvlgamma_::call(const_cast(*this), p); +} + +// aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor +inline at::Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) const { + return at::_ops::narrow_copy::call(const_cast(*this), dim, start, length); +} + +// aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor +inline at::Tensor Tensor::narrow_copy_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const { + return at::_ops::narrow_copy::call(const_cast(*this), dim, start, length); +} + +// aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { + return at::_ops::narrow::call(const_cast(*this), dim, start, length); +} + +// aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const { + return at::_ops::narrow::call(const_cast(*this), dim, start, length); +} + +// aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow(int64_t dim, const at::Tensor & start, int64_t length) const { + return at::_ops::narrow_Tensor::call(const_cast(*this), dim, start, length); +} + +// aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow_symint(int64_t dim, const at::Tensor & start, c10::SymInt length) const { + return at::_ops::narrow_Tensor::call(const_cast(*this), dim, start, length); +} + +// aten::permute(Tensor(a) self, int[] dims) -> Tensor(a) +inline at::Tensor Tensor::permute(at::IntArrayRef dims) const { + return at::_ops::permute::call(const_cast(*this), dims); +} + +// aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) +inline at::Tensor Tensor::movedim(at::IntArrayRef source, at::IntArrayRef destination) const { + return at::_ops::movedim_intlist::call(const_cast(*this), source, destination); +} + +// 
aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) +inline at::Tensor Tensor::movedim(int64_t source, int64_t destination) const { + return at::_ops::movedim_int::call(const_cast(*this), source, destination); +} + +// aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) +inline at::Tensor Tensor::moveaxis(at::IntArrayRef source, at::IntArrayRef destination) const { + return at::_ops::moveaxis_intlist::call(const_cast(*this), source, destination); +} + +// aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a) +inline at::Tensor Tensor::moveaxis(int64_t source, int64_t destination) const { + return at::_ops::moveaxis_int::call(const_cast(*this), source, destination); +} + +// aten::numpy_T(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::numpy_T() const { + return at::_ops::numpy_T::call(const_cast(*this)); +} + +// aten::matrix_H(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::matrix_H() const { + return at::_ops::matrix_H::call(const_cast(*this)); +} + +// aten::mT(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::mT() const { + return at::_ops::mT::call(const_cast(*this)); +} + +// aten::mH(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::mH() const { + return at::_ops::mH::call(const_cast(*this)); +} + +// aten::adjoint(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::adjoint() const { + return at::_ops::adjoint::call(const_cast(*this)); +} + +// aten::is_pinned(Tensor self, Device? device=None) -> bool +inline bool Tensor::is_pinned(::std::optional device) const { + return at::_ops::is_pinned::call(const_cast(*this), device); +} + +// aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a) +inline at::Tensor Tensor::pin_memory(::std::optional device) const { + return at::_ops::pin_memory::call(const_cast(*this), device); +} + +// aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor +inline at::Tensor Tensor::pinverse(double rcond) const { + return at::_ops::pinverse::call(const_cast(*this), rcond); +} + +// aten::rad2deg(Tensor self) -> Tensor +inline at::Tensor Tensor::rad2deg() const { + return at::_ops::rad2deg::call(const_cast(*this)); +} + +// aten::rad2deg_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::rad2deg_() const { + return at::_ops::rad2deg_::call(const_cast(*this)); +} + +// aten::deg2rad(Tensor self) -> Tensor +inline at::Tensor Tensor::deg2rad() const { + return at::_ops::deg2rad::call(const_cast(*this)); +} + +// aten::deg2rad_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::deg2rad_() const { + return at::_ops::deg2rad_::call(const_cast(*this)); +} + +// aten::ravel(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::ravel() const { + return at::_ops::ravel::call(const_cast(*this)); +} + +// aten::reciprocal(Tensor self) -> Tensor +inline at::Tensor Tensor::reciprocal() const { + return at::_ops::reciprocal::call(const_cast(*this)); +} + +// aten::reciprocal_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::reciprocal_() const { + return at::_ops::reciprocal_::call(const_cast(*this)); +} + +// aten::neg(Tensor self) -> Tensor +inline at::Tensor Tensor::neg() const { + return at::_ops::neg::call(const_cast(*this)); +} + +// aten::neg_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::neg_() const { + return at::_ops::neg_::call(const_cast(*this)); +} + +// aten::negative(Tensor self) -> Tensor +inline at::Tensor Tensor::negative() const { + return at::_ops::negative::call(const_cast(*this)); +} + +// aten::negative_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::negative_() const { + return at::_ops::negative_::call(const_cast(*this)); +} + +// aten::repeat(Tensor self, SymInt[] repeats) -> Tensor +inline at::Tensor Tensor::repeat(at::IntArrayRef repeats) const { + return at::_ops::repeat::call(const_cast(*this), c10::fromIntArrayRefSlow(repeats)); +} + +// aten::repeat(Tensor self, SymInt[] repeats) -> Tensor +inline at::Tensor Tensor::repeat_symint(c10::SymIntArrayRef repeats) const { + return at::_ops::repeat::call(const_cast(*this), repeats); +} + +// aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave(const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_Tensor::call(const_cast(*this), repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); +} + +// aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave_symint(const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_Tensor::call(const_cast(*this), repeats, dim, output_size); +} + +// aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave(int64_t repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_int::call(const_cast(*this), repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); +} + +// aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? 
output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave_symint(c10::SymInt repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_int::call(const_cast(*this), repeats, dim, output_size); +} + +// aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) +inline at::Tensor Tensor::reshape(at::IntArrayRef shape) const { + return at::_ops::reshape::call(const_cast(*this), c10::fromIntArrayRefSlow(shape)); +} + +// aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) +inline at::Tensor Tensor::reshape_symint(c10::SymIntArrayRef shape) const { + return at::_ops::reshape::call(const_cast(*this), shape); +} + +// aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) +inline at::Tensor Tensor::_reshape_alias(at::IntArrayRef size, at::IntArrayRef stride) const { + return at::_ops::_reshape_alias::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); +} + +// aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) +inline at::Tensor Tensor::_reshape_alias_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const { + return at::_ops::_reshape_alias::call(const_cast(*this), size, stride); +} + +// aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a) +inline at::Tensor Tensor::reshape_as(const at::Tensor & other) const { + return at::_ops::reshape_as::call(const_cast(*this), other); +} + +// aten::round(Tensor self) -> Tensor +inline at::Tensor Tensor::round() const { + return at::_ops::round::call(const_cast(*this)); +} + +// aten::round_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::round_() const { + return at::_ops::round_::call(const_cast(*this)); +} + +// aten::round.decimals(Tensor self, *, int decimals) -> Tensor +inline at::Tensor Tensor::round(int64_t decimals) const { + return at::_ops::round_decimals::call(const_cast(*this), decimals); +} + +// aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!) +inline at::Tensor & Tensor::round_(int64_t decimals) const { + return at::_ops::round__decimals::call(const_cast(*this), decimals); +} + +// aten::relu(Tensor self) -> Tensor +inline at::Tensor Tensor::relu() const { + return at::_ops::relu::call(const_cast(*this)); +} + +// aten::relu_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::relu_() const { + return at::_ops::relu_::call(const_cast(*this)); +} + +// aten::prelu(Tensor self, Tensor weight) -> Tensor +inline at::Tensor Tensor::prelu(const at::Tensor & weight) const { + return at::_ops::prelu::call(const_cast(*this), weight); +} + +// aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor +inline at::Tensor Tensor::hardshrink(const at::Scalar & lambd) const { + return at::_ops::hardshrink::call(const_cast(*this), lambd); +} + +// aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor +inline at::Tensor Tensor::hardshrink_backward(const at::Tensor & grad_out, const at::Scalar & lambd) const { + return at::_ops::hardshrink_backward::call(grad_out, const_cast(*this), lambd); +} + +// aten::rsqrt(Tensor self) -> Tensor +inline at::Tensor Tensor::rsqrt() const { + return at::_ops::rsqrt::call(const_cast(*this)); +} + +// aten::rsqrt_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::rsqrt_() const { + return at::_ops::rsqrt_::call(const_cast(*this)); +} + +// aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) +inline at::Tensor Tensor::select(at::Dimname dim, int64_t index) const { + return at::_ops::select_Dimname::call(const_cast(*this), dim, index); +} + +// aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) +inline at::Tensor Tensor::select(int64_t dim, int64_t index) const { + return at::_ops::select_int::call(const_cast(*this), dim, index); +} + +// aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) +inline at::Tensor Tensor::select_symint(int64_t dim, c10::SymInt index) const { + return at::_ops::select_int::call(const_cast(*this), dim, index); +} + +// aten::sigmoid(Tensor self) -> Tensor +inline at::Tensor Tensor::sigmoid() const { + return at::_ops::sigmoid::call(const_cast(*this)); +} + +// aten::sigmoid_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sigmoid_() const { + return at::_ops::sigmoid_::call(const_cast(*this)); +} + +// aten::logit(Tensor self, float? eps=None) -> Tensor +inline at::Tensor Tensor::logit(::std::optional eps) const { + return at::_ops::logit::call(const_cast(*this), eps); +} + +// aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) +inline at::Tensor & Tensor::logit_(::std::optional eps) const { + return at::_ops::logit_::call(const_cast(*this), eps); +} + +// aten::sin(Tensor self) -> Tensor +inline at::Tensor Tensor::sin() const { + return at::_ops::sin::call(const_cast(*this)); +} + +// aten::sin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sin_() const { + return at::_ops::sin_::call(const_cast(*this)); +} + +// aten::sinc(Tensor self) -> Tensor +inline at::Tensor Tensor::sinc() const { + return at::_ops::sinc::call(const_cast(*this)); +} + +// aten::sinc_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sinc_() const { + return at::_ops::sinc_::call(const_cast(*this)); +} + +// aten::sinh(Tensor self) -> Tensor +inline at::Tensor Tensor::sinh() const { + return at::_ops::sinh::call(const_cast(*this)); +} + +// aten::sinh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sinh_() const { + return at::_ops::sinh_::call(const_cast(*this)); +} + +// aten::detach(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::detach() const { + return at::_ops::detach::call(const_cast(*this)); +} + +// aten::detach_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::detach_() const { + return at::_ops::detach_::call(const_cast(*this)); +} + +// aten::size.Dimname(Tensor self, Dimname dim) -> int +inline int64_t Tensor::size(at::Dimname dim) const { + return at::_ops::size_Dimname::call(const_cast(*this), dim); +} + +// aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice(int64_t dim, ::std::optional start, ::std::optional end, int64_t step) const { + return at::_ops::slice_Tensor::call(const_cast(*this), dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); +} + +// aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice_symint(int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) const { + return at::_ops::slice_Tensor::call(const_cast(*this), dim, start, end, step); +} + +// aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice_inverse(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, int64_t step) const { + return at::_ops::slice_inverse::call(const_cast(*this), src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); +} + +// aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice_inverse_symint(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) const { + return at::_ops::slice_inverse::call(const_cast(*this), src, dim, start, end, step); +} + +// aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor +inline at::Tensor Tensor::slice_scatter(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, int64_t step) const { + return at::_ops::slice_scatter::call(const_cast(*this), src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); +} + +// aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor +inline at::Tensor Tensor::slice_scatter_symint(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) const { + return at::_ops::slice_scatter::call(const_cast(*this), src, dim, start, end, step); +} + +// aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor +inline at::Tensor Tensor::select_scatter(const at::Tensor & src, int64_t dim, int64_t index) const { + return at::_ops::select_scatter::call(const_cast(*this), src, dim, index); +} + +// aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor +inline at::Tensor Tensor::select_scatter_symint(const at::Tensor & src, int64_t dim, c10::SymInt index) const { + return at::_ops::select_scatter::call(const_cast(*this), src, dim, index); +} + +// aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor +inline at::Tensor Tensor::diagonal_scatter(const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2) const { + return at::_ops::diagonal_scatter::call(const_cast(*this), src, offset, dim1, dim2); +} + +// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor +inline at::Tensor Tensor::as_strided_scatter(const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_scatter::call(const_cast(*this), src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} + +// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor +inline at::Tensor Tensor::as_strided_scatter_symint(const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_scatter::call(const_cast(*this), src, size, stride, storage_offset); +} + +// aten::smm(Tensor self, Tensor mat2) -> Tensor +inline at::Tensor Tensor::smm(const at::Tensor & mat2) const { + return at::_ops::smm::call(const_cast(*this), mat2); +} + +// aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::softmax(int64_t dim, ::std::optional dtype) const { + return at::_ops::softmax_int::call(const_cast(*this), dim, dtype); +} + +// aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::softmax(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::softmax_Dimname::call(const_cast(*this), dim, dtype); +} + +// aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split(int64_t split_size, int64_t dim) const { + return at::_ops::unsafe_split_Tensor::call(const_cast(*this), split_size, dim); +} + +// aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split_symint(c10::SymInt split_size, int64_t dim) const { + return at::_ops::unsafe_split_Tensor::call(const_cast(*this), split_size, dim); +} + +// aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split(int64_t split_size, int64_t dim) const { + return at::_ops::split_Tensor::call(const_cast(*this), split_size, dim); +} + +// aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split_symint(c10::SymInt split_size, int64_t dim) const { + return at::_ops::split_Tensor::call(const_cast(*this), split_size, dim); +} + +// aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split(at::IntArrayRef split_size, int64_t dim) const { + return at::_ops::split_sizes::call(const_cast(*this), c10::fromIntArrayRefSlow(split_size), dim); +} + +// aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split_symint(c10::SymIntArrayRef split_size, int64_t dim) const { + return at::_ops::split_sizes::call(const_cast(*this), split_size, dim); +} + +// aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split_with_sizes(at::IntArrayRef split_sizes, int64_t dim) const { + return at::_ops::unsafe_split_with_sizes::call(const_cast(*this), c10::fromIntArrayRefSlow(split_sizes), dim); +} + +// aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim) const { + return at::_ops::unsafe_split_with_sizes::call(const_cast(*this), split_sizes, dim); +} + +// aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split_with_sizes(at::IntArrayRef split_sizes, int64_t dim) const { + return at::_ops::split_with_sizes::call(const_cast(*this), c10::fromIntArrayRefSlow(split_sizes), dim); +} + +// aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] +inline 
::std::vector Tensor::split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim) const { + return at::_ops::split_with_sizes::call(const_cast(*this), split_sizes, dim); +} + +// aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] +inline ::std::vector Tensor::hsplit(int64_t sections) const { + return at::_ops::hsplit_int::call(const_cast(*this), sections); +} + +// aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] +inline ::std::vector Tensor::hsplit(at::IntArrayRef indices) const { + return at::_ops::hsplit_array::call(const_cast(*this), indices); +} + +// aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] +inline ::std::vector Tensor::vsplit(int64_t sections) const { + return at::_ops::vsplit_int::call(const_cast(*this), sections); +} + +// aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] +inline ::std::vector Tensor::vsplit(at::IntArrayRef indices) const { + return at::_ops::vsplit_array::call(const_cast(*this), indices); +} + +// aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] +inline ::std::vector Tensor::dsplit(int64_t sections) const { + return at::_ops::dsplit_int::call(const_cast(*this), sections); +} + +// aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] +inline ::std::vector Tensor::dsplit(at::IntArrayRef indices) const { + return at::_ops::dsplit_array::call(const_cast(*this), indices); +} + +// aten::squeeze(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::squeeze() const { + return at::_ops::squeeze::call(const_cast(*this)); +} + +// aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) +inline at::Tensor Tensor::squeeze(int64_t dim) const { + return at::_ops::squeeze_dim::call(const_cast(*this), dim); +} + +// aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a) +inline at::Tensor Tensor::squeeze(at::Dimname dim) const { + return at::_ops::squeeze_dimname::call(const_cast(*this), dim); +} + +// aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) +inline at::Tensor Tensor::squeeze(at::IntArrayRef dim) const { + return at::_ops::squeeze_dims::call(const_cast(*this), dim); +} + +// aten::squeeze_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::squeeze_() const { + return at::_ops::squeeze_::call(const_cast(*this)); +} + +// aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) +inline at::Tensor & Tensor::squeeze_(int64_t dim) const { + return at::_ops::squeeze__dim::call(const_cast(*this), dim); +} + +// aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) +inline at::Tensor & Tensor::squeeze_(at::IntArrayRef dim) const { + return at::_ops::squeeze__dims::call(const_cast(*this), dim); +} + +// aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) +inline at::Tensor & Tensor::squeeze_(at::Dimname dim) const { + return at::_ops::squeeze__dimname::call(const_cast(*this), dim); +} + +// aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::sspaddmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::sspaddmm::call(const_cast(*this), mat1, mat2, beta, alpha); +} + +// aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? 
align_to_window=None) -> Tensor +inline at::Tensor Tensor::stft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) const { + return at::_ops::stft::call(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided, return_complex, align_to_window); +} + +// aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor +inline at::Tensor Tensor::stft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, c10::string_view pad_mode, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) const { + return at::_ops::stft_center::call(const_cast(*this), n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex, align_to_window); +} + +// aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor +inline at::Tensor Tensor::istft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, bool normalized, ::std::optional onesided, ::std::optional length, bool return_complex) const { + return at::_ops::istft::call(const_cast(*this), n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex); +} + +// aten::stride.Dimname(Tensor self, Dimname dim) -> int +inline int64_t Tensor::stride(at::Dimname dim) const { + return at::_ops::stride_Dimname::call(const_cast(*this), dim); +} + +// aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::sum(::std::optional dtype) const { + return at::_ops::sum::call(const_cast(*this), dtype); +} + +// aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::sum(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::sum_dim_IntList::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::sum(at::DimnameList dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::sum_dim_DimnameList::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor +inline at::Tensor Tensor::nansum(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::nansum::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor +inline at::Tensor Tensor::sum_to_size(at::IntArrayRef size) const { + return at::_ops::sum_to_size::call(const_cast(*this), c10::fromIntArrayRefSlow(size)); +} + +// aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor +inline at::Tensor Tensor::sum_to_size_symint(c10::SymIntArrayRef size) const { + return at::_ops::sum_to_size::call(const_cast(*this), size); +} + +// aten::sqrt(Tensor self) -> Tensor +inline at::Tensor Tensor::sqrt() const { + return at::_ops::sqrt::call(const_cast(*this)); +} + +// aten::sqrt_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sqrt_() const { + return at::_ops::sqrt_::call(const_cast(*this)); +} + +// aten::square(Tensor self) -> Tensor +inline at::Tensor Tensor::square() const { + return at::_ops::square::call(const_cast(*this)); +} + +// aten::square_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::square_() const { + return at::_ops::square_::call(const_cast(*this)); +} + +// aten::std(Tensor self, bool unbiased=True) -> Tensor +inline at::Tensor Tensor::std(bool unbiased) const { + return at::_ops::std::call(const_cast(*this), unbiased); +} + +// aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) const { + return at::_ops::std_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::std_correction::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::DimnameList dim, bool unbiased, bool keepdim) const { + return at::_ops::std_names_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::DimnameList dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::std_correction_names::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::prod(::std::optional dtype) const { + return at::_ops::prod::call(const_cast(*this), dtype); +} + +// aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::prod(int64_t dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::prod_dim_int::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::prod(at::Dimname dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::prod_dim_Dimname::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::t(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::t() const { + return at::_ops::t::call(const_cast(*this)); +} + +// aten::t_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::t_() const { + return at::_ops::t_::call(const_cast(*this)); +} + +// aten::tan(Tensor self) -> Tensor +inline at::Tensor Tensor::tan() const { + return at::_ops::tan::call(const_cast(*this)); +} + +// aten::tan_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::tan_() const { + return at::_ops::tan_::call(const_cast(*this)); +} + +// aten::tanh(Tensor self) -> Tensor +inline at::Tensor Tensor::tanh() const { + return at::_ops::tanh::call(const_cast(*this)); +} + +// aten::tanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::tanh_() const { + return at::_ops::tanh_::call(const_cast(*this)); +} + +// aten::tile(Tensor self, SymInt[] dims) -> Tensor +inline at::Tensor Tensor::tile(at::IntArrayRef dims) const { + return at::_ops::tile::call(const_cast(*this), c10::fromIntArrayRefSlow(dims)); +} + +// aten::tile(Tensor self, SymInt[] dims) -> Tensor +inline at::Tensor Tensor::tile_symint(c10::SymIntArrayRef dims) const { + return at::_ops::tile::call(const_cast(*this), dims); +} + +// aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) +inline at::Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { + return at::_ops::transpose_int::call(const_cast(*this), dim0, dim1); +} + +// aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) +inline at::Tensor Tensor::transpose(at::Dimname dim0, at::Dimname dim1) const { + return at::_ops::transpose_Dimname::call(const_cast(*this), dim0, dim1); +} + +// aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) +inline at::Tensor & Tensor::transpose_(int64_t dim0, int64_t dim1) const { + return at::_ops::transpose_::call(const_cast(*this), dim0, dim1); +} + +// aten::flip(Tensor self, int[] dims) -> Tensor +inline at::Tensor Tensor::flip(at::IntArrayRef dims) const { + return at::_ops::flip::call(const_cast(*this), dims); +} + +// aten::fliplr(Tensor self) -> Tensor +inline at::Tensor Tensor::fliplr() const { + return at::_ops::fliplr::call(const_cast(*this)); +} + +// aten::flipud(Tensor self) -> Tensor +inline at::Tensor Tensor::flipud() const { + return at::_ops::flipud::call(const_cast(*this)); +} + +// aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor +inline at::Tensor Tensor::roll(at::IntArrayRef shifts, at::IntArrayRef dims) const { + return at::_ops::roll::call(const_cast(*this), c10::fromIntArrayRefSlow(shifts), dims); +} + +// aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor +inline at::Tensor Tensor::roll_symint(c10::SymIntArrayRef shifts, at::IntArrayRef dims) const { + return at::_ops::roll::call(const_cast(*this), shifts, dims); +} + +// aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor +inline at::Tensor Tensor::rot90(int64_t k, at::IntArrayRef dims) const { + return at::_ops::rot90::call(const_cast(*this), k, dims); +} + +// aten::_nested_tensor_size(Tensor self) -> Tensor +inline at::Tensor Tensor::_nested_tensor_size() const { + return at::_ops::_nested_tensor_size::call(const_cast(*this)); +} + +// aten::_nested_tensor_strides(Tensor self) -> Tensor +inline at::Tensor Tensor::_nested_tensor_strides() const { + return at::_ops::_nested_tensor_strides::call(const_cast(*this)); +} + +// aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor +inline at::Tensor Tensor::_nested_tensor_storage_offsets() const { + return at::_ops::_nested_tensor_storage_offsets::call(const_cast(*this)); +} + +// aten::trunc(Tensor self) -> Tensor +inline at::Tensor Tensor::trunc() const { + 
return at::_ops::trunc::call(const_cast(*this)); +} + +// aten::trunc_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::trunc_() const { + return at::_ops::trunc_::call(const_cast(*this)); +} + +// aten::fix(Tensor self) -> Tensor +inline at::Tensor Tensor::fix() const { + return at::_ops::fix::call(const_cast(*this)); +} + +// aten::fix_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::fix_() const { + return at::_ops::fix_::call(const_cast(*this)); +} + +// aten::type_as(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::type_as(const at::Tensor & other) const { + return at::_ops::type_as::call(const_cast(*this), other); +} + +// aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a) +inline at::Tensor Tensor::unsqueeze(int64_t dim) const { + return at::_ops::unsqueeze::call(const_cast(*this), dim); +} + +// aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) +inline at::Tensor & Tensor::unsqueeze_(int64_t dim) const { + return at::_ops::unsqueeze_::call(const_cast(*this), dim); +} + +// aten::var(Tensor self, bool unbiased=True) -> Tensor +inline at::Tensor Tensor::var(bool unbiased) const { + return at::_ops::var::call(const_cast(*this), unbiased); +} + +// aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) const { + return at::_ops::var_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::var_correction::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::DimnameList dim, bool unbiased, bool keepdim) const { + return at::_ops::var_names_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::DimnameList dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::var_correction_names::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a) +inline at::Tensor Tensor::view_as(const at::Tensor & other) const { + return at::_ops::view_as::call(const_cast(*this), other); +} + +// aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::where(const at::Tensor & condition, const at::Tensor & other) const { + return at::_ops::where_self::call(condition, const_cast(*this), other); +} + +// aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::where(const at::Tensor & condition, const at::Scalar & other) const { + return at::_ops::where_ScalarOther::call(condition, const_cast(*this), other); +} + +// aten::norm.ScalarOpt_dtype(Tensor self, Scalar? 
p, *, ScalarType dtype) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::ScalarType dtype) const { + return at::_ops::norm_ScalarOpt_dtype::call(const_cast(*this), p, dtype); +} + +// aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor +inline at::Tensor Tensor::norm(const at::Scalar & p) const { + return at::_ops::norm_Scalar::call(const_cast(*this), p); +} + +// aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) const { + return at::_ops::norm_ScalarOpt_dim_dtype::call(const_cast(*this), p, dim, keepdim, dtype); +} + +// aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim) const { + return at::_ops::norm_ScalarOpt_dim::call(const_cast(*this), p, dim, keepdim); +} + +// aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) const { + return at::_ops::norm_names_ScalarOpt_dim_dtype::call(const_cast(*this), p, dim, keepdim, dtype); +} + +// aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::DimnameList dim, bool keepdim) const { + return at::_ops::norm_names_ScalarOpt_dim::call(const_cast(*this), p, dim, keepdim); +} + +// aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) +inline ::std::tuple Tensor::frexp() const { + return at::_ops::frexp_Tensor::call(const_cast(*this)); +} + +// aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor +inline at::Tensor Tensor::clone(::std::optional memory_format) const { + return at::_ops::clone::call(const_cast(*this), memory_format); +} + +// aten::positive(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::positive() const { + return at::_ops::positive::call(const_cast(*this)); +} + +// aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) +inline const at::Tensor & Tensor::resize_as_(const at::Tensor & the_template, ::std::optional memory_format) const { + return at::_ops::resize_as_::call(const_cast(*this), the_template, memory_format); +} + +// aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) +inline const at::Tensor & Tensor::resize_as_sparse_(const at::Tensor & the_template) const { + return at::_ops::resize_as_sparse_::call(const_cast(*this), the_template); +} + +// aten::zero_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::zero_() const { + return at::_ops::zero_::call(const_cast(*this)); +} + +// aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::sub(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::sub_Tensor::call(const_cast(*this), other, alpha); +} + +// aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) 
+inline at::Tensor & Tensor::sub_(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::sub__Tensor::call(const_cast<Tensor&>(*this), other, alpha);
+}
+
+// aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::sub(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::sub_Scalar::call(const_cast<Tensor&>(*this), other, alpha);
+}
+
+// aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::sub_(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::sub__Scalar::call(const_cast<Tensor&>(*this), other, alpha);
+}
+
+// aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::subtract(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract_Tensor::call(const_cast<Tensor&>(*this), other, alpha);
+}
+
+// aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::subtract_(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract__Tensor::call(const_cast<Tensor&>(*this), other, alpha);
+}
+
+// aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::subtract(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract_Scalar::call(const_cast<Tensor&>(*this), other, alpha);
+}
+
+// aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::subtract_(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract__Scalar::call(const_cast<Tensor&>(*this), other, alpha);
+}
+
+// aten::heaviside(Tensor self, Tensor values) -> Tensor
+inline at::Tensor Tensor::heaviside(const at::Tensor & values) const {
+    return at::_ops::heaviside::call(const_cast<Tensor&>(*this), values);
+}
+
+// aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+inline at::Tensor & Tensor::heaviside_(const at::Tensor & values) const {
+    return at::_ops::heaviside_::call(const_cast<Tensor&>(*this), values);
+}
+
+// aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::addmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addmm::call(const_cast<Tensor&>(*this), mat1, mat2, beta, alpha);
+}
+
+// aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::addmm_(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addmm_::call(const_cast<Tensor&>(*this), mat1, mat2, beta, alpha);
+}
+
+// aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
+inline at::Tensor Tensor::_addmm_activation(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu) const {
+    return at::_ops::_addmm_activation::call(const_cast<Tensor&>(*this), mat1, mat2, beta, alpha, use_gelu);
+}
+
+// aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+inline const at::Tensor & Tensor::sparse_resize_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const {
+    return at::_ops::sparse_resize_::call(const_cast<Tensor&>(*this), size, sparse_dim, dense_dim);
+}
+
+// aten::sparse_resize_and_clear_(Tensor(a!)
self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) +inline const at::Tensor & Tensor::sparse_resize_and_clear_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { + return at::_ops::sparse_resize_and_clear_::call(const_cast(*this), size, sparse_dim, dense_dim); +} + +// aten::sparse_mask(Tensor self, Tensor mask) -> Tensor +inline at::Tensor Tensor::sparse_mask(const at::Tensor & mask) const { + return at::_ops::sparse_mask::call(const_cast(*this), mask); +} + +// aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor +inline at::Tensor Tensor::_sparse_mask_projection(const at::Tensor & mask, bool accumulate_matches) const { + return at::_ops::_sparse_mask_projection::call(const_cast(*this), mask, accumulate_matches); +} + +// aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor +inline at::Tensor Tensor::to_dense(::std::optional dtype, ::std::optional masked_grad) const { + return at::_ops::to_dense::call(const_cast(*this), dtype, masked_grad); +} + +// aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor +inline at::Tensor Tensor::_to_dense(::std::optional dtype, ::std::optional masked_grad) const { + return at::_ops::_to_dense::call(const_cast(*this), dtype, masked_grad); +} + +// aten::sparse_dim(Tensor self) -> int +inline int64_t Tensor::sparse_dim() const { + return at::_ops::sparse_dim::call(const_cast(*this)); +} + +// aten::_dimI(Tensor self) -> int +inline int64_t Tensor::_dimI() const { + return at::_ops::_dimI::call(const_cast(*this)); +} + +// aten::dense_dim(Tensor self) -> int +inline int64_t Tensor::dense_dim() const { + return at::_ops::dense_dim::call(const_cast(*this)); +} + +// aten::_dimV(Tensor self) -> int +inline int64_t Tensor::_dimV() const { + return at::_ops::_dimV::call(const_cast(*this)); +} + +// aten::_nnz(Tensor self) -> int +inline int64_t Tensor::_nnz() const { + return at::_ops::_nnz::call(const_cast(*this)); +} + +// aten::coalesce(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::coalesce() const { + return at::_ops::coalesce::call(const_cast(*this)); +} + +// aten::is_coalesced(Tensor self) -> bool +inline bool Tensor::is_coalesced() const { + return at::_ops::is_coalesced::call(const_cast(*this)); +} + +// aten::_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::_indices() const { + return at::_ops::_indices::call(const_cast(*this)); +} + +// aten::_values(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::_values() const { + return at::_ops::_values::call(const_cast(*this)); +} + +// aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) 
+inline at::Tensor & Tensor::_coalesced_(bool coalesced) const { + return at::_ops::_coalesced_::call(const_cast(*this), coalesced); +} + +// aten::indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::indices() const { + return at::_ops::indices::call(const_cast(*this)); +} + +// aten::values(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::values() const { + return at::_ops::values::call(const_cast(*this)); +} + +// aten::crow_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::crow_indices() const { + return at::_ops::crow_indices::call(const_cast(*this)); +} + +// aten::col_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::col_indices() const { + return at::_ops::col_indices::call(const_cast(*this)); +} + +// aten::ccol_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::ccol_indices() const { + return at::_ops::ccol_indices::call(const_cast(*this)); +} + +// aten::row_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::row_indices() const { + return at::_ops::row_indices::call(const_cast(*this)); +} + +// aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::unbind(int64_t dim) const { + return at::_ops::unbind_int::call(const_cast(*this), dim); +} + +// aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[] +inline ::std::vector Tensor::unbind(at::Dimname dim) const { + return at::_ops::unbind_Dimname::call(const_cast(*this), dim); +} + +// aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor +inline at::Tensor Tensor::to_sparse(int64_t sparse_dim) const { + return at::_ops::to_sparse_sparse_dim::call(const_cast(*this), sparse_dim); +} + +// aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor +inline at::Tensor Tensor::_to_sparse(int64_t sparse_dim) const { + return at::_ops::_to_sparse_sparse_dim::call(const_cast(*this), sparse_dim); +} + +// aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse(::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::to_sparse::call(const_cast(*this), layout, blocksize, dense_dim); +} + +// aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse(::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::_to_sparse::call(const_cast(*this), layout, blocksize, dense_dim); +} + +// aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_csr(::std::optional dense_dim) const { + return at::_ops::to_sparse_csr::call(const_cast(*this), dense_dim); +} + +// aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_csr(::std::optional dense_dim) const { + return at::_ops::_to_sparse_csr::call(const_cast(*this), dense_dim); +} + +// aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_csc(::std::optional dense_dim) const { + return at::_ops::to_sparse_csc::call(const_cast(*this), dense_dim); +} + +// aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_csc(::std::optional dense_dim) const { + return at::_ops::_to_sparse_csc::call(const_cast(*this), dense_dim); +} + +// aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? 
dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::to_sparse_bsr::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::_to_sparse_bsr::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::to_sparse_bsc::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::_to_sparse_bsc::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::to_mkldnn(::std::optional dtype) const { + return at::_ops::to_mkldnn::call(const_cast(*this), dtype); +} + +// aten::dequantize.self(Tensor self) -> Tensor +inline at::Tensor Tensor::dequantize() const { + return at::_ops::dequantize_self::call(const_cast(*this)); +} + +// aten::q_scale(Tensor self) -> float +inline double Tensor::q_scale() const { + return at::_ops::q_scale::call(const_cast(*this)); +} + +// aten::q_zero_point(Tensor self) -> int +inline int64_t Tensor::q_zero_point() const { + return at::_ops::q_zero_point::call(const_cast(*this)); +} + +// aten::q_per_channel_scales(Tensor self) -> Tensor +inline at::Tensor Tensor::q_per_channel_scales() const { + return at::_ops::q_per_channel_scales::call(const_cast(*this)); +} + +// aten::q_per_channel_zero_points(Tensor self) -> Tensor +inline at::Tensor Tensor::q_per_channel_zero_points() const { + return at::_ops::q_per_channel_zero_points::call(const_cast(*this)); +} + +// aten::q_per_channel_axis(Tensor self) -> int +inline int64_t Tensor::q_per_channel_axis() const { + return at::_ops::q_per_channel_axis::call(const_cast(*this)); +} + +// aten::int_repr(Tensor self) -> Tensor +inline at::Tensor Tensor::int_repr() const { + return at::_ops::int_repr::call(const_cast(*this)); +} + +// aten::qscheme(Tensor self) -> QScheme +inline at::QScheme Tensor::qscheme() const { + return at::_ops::qscheme::call(const_cast(*this)); +} + +// aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a) +inline at::Tensor Tensor::_autocast_to_reduced_precision(bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) const { + return at::_ops::_autocast_to_reduced_precision::call(const_cast(*this), cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); +} + +// aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a) +inline at::Tensor Tensor::_autocast_to_full_precision(bool cuda_enabled, bool cpu_enabled) const { + return at::_ops::_autocast_to_full_precision::call(const_cast(*this), cuda_enabled, cpu_enabled); +} + +// aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(at::TensorOptions options, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_dtype_layout::call(const_cast(*this), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, copy, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); +} + +// aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_dtype_layout::call(const_cast(*this), dtype, layout, device, pin_memory, non_blocking, copy, memory_format); +} + +// aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_device::call(const_cast(*this), device, dtype, non_blocking, copy, memory_format); +} + +// aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_dtype::call(const_cast(*this), dtype, non_blocking, copy, memory_format); +} + +// aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(const at::Tensor & other, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_other::call(const_cast(*this), other, non_blocking, copy, memory_format); +} + +// aten::item(Tensor self) -> Scalar +inline at::Scalar Tensor::item() const { + return at::_ops::item::call(const_cast(*this)); +} + +// aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!) +inline at::Tensor & Tensor::set_(at::Storage source) const { + return at::_ops::set__source_Storage::call(const_cast(*this), source); +} + +// aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) +inline at::Tensor & Tensor::set_(at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride) const { + return at::_ops::set__source_Storage_storage_offset::call(const_cast(*this), source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); +} + +// aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) +inline at::Tensor & Tensor::set__symint(at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const { + return at::_ops::set__source_Storage_storage_offset::call(const_cast(*this), source, storage_offset, size, stride); +} + +// aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) 
+inline at::Tensor & Tensor::set_(const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride) const { + return at::_ops::set__source_Tensor_storage_offset::call(const_cast(*this), source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); +} + +// aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) +inline at::Tensor & Tensor::set__symint(const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const { + return at::_ops::set__source_Tensor_storage_offset::call(const_cast(*this), source, storage_offset, size, stride); +} + +// aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) +inline at::Tensor & Tensor::set_(const at::Tensor & source) const { + return at::_ops::set__source_Tensor::call(const_cast(*this), source); +} + +// aten::set_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::set_() const { + return at::_ops::set_::call(const_cast(*this)); +} + +// aten::is_set_to(Tensor self, Tensor tensor) -> bool +inline bool Tensor::is_set_to(const at::Tensor & tensor) const { + return at::_ops::is_set_to::call(const_cast(*this), tensor); +} + +// aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::masked_fill_(const at::Tensor & mask, const at::Scalar & value) const { + return at::_ops::masked_fill__Scalar::call(const_cast(*this), mask, value); +} + +// aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor +inline at::Tensor Tensor::masked_fill(const at::Tensor & mask, const at::Scalar & value) const { + return at::_ops::masked_fill_Scalar::call(const_cast(*this), mask, value); +} + +// aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) +inline at::Tensor & Tensor::masked_fill_(const at::Tensor & mask, const at::Tensor & value) const { + return at::_ops::masked_fill__Tensor::call(const_cast(*this), mask, value); +} + +// aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor +inline at::Tensor Tensor::masked_fill(const at::Tensor & mask, const at::Tensor & value) const { + return at::_ops::masked_fill_Tensor::call(const_cast(*this), mask, value); +} + +// aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) +inline at::Tensor & Tensor::masked_scatter_(const at::Tensor & mask, const at::Tensor & source) const { + return at::_ops::masked_scatter_::call(const_cast(*this), mask, source); +} + +// aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor +inline at::Tensor Tensor::masked_scatter(const at::Tensor & mask, const at::Tensor & source) const { + return at::_ops::masked_scatter::call(const_cast(*this), mask, source); +} + +// aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::view(at::IntArrayRef size) const { + return at::_ops::view::call(const_cast(*this), c10::fromIntArrayRefSlow(size)); +} + +// aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::view_symint(c10::SymIntArrayRef size) const { + return at::_ops::view::call(const_cast(*this), size); +} + +// aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) +inline at::Tensor Tensor::view(at::ScalarType dtype) const { + return at::_ops::view_dtype::call(const_cast(*this), dtype); +} + +// aten::put_(Tensor(a!) 
self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) +inline at::Tensor & Tensor::put_(const at::Tensor & index, const at::Tensor & source, bool accumulate) const { + return at::_ops::put_::call(const_cast(*this), index, source, accumulate); +} + +// aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor +inline at::Tensor Tensor::put(const at::Tensor & index, const at::Tensor & source, bool accumulate) const { + return at::_ops::put::call(const_cast(*this), index, source, accumulate); +} + +// aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::index_add_(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const { + return at::_ops::index_add_::call(const_cast(*this), dim, index, source, alpha); +} + +// aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::index_add(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const { + return at::_ops::index_add::call(const_cast(*this), dim, index, source, alpha); +} + +// aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::index_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const { + return at::_ops::index_add_dimname::call(const_cast(*this), dim, index, source, alpha); +} + +// aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!) +inline at::Tensor & Tensor::index_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) const { + return at::_ops::index_reduce_::call(const_cast(*this), dim, index, source, reduce, include_self); +} + +// aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor +inline at::Tensor Tensor::index_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) const { + return at::_ops::index_reduce::call(const_cast(*this), dim, index, source, reduce, include_self); +} + +// aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::index_fill_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill__int_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::index_fill(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill_int_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) 
+inline at::Tensor & Tensor::index_fill_(int64_t dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill__int_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor +inline at::Tensor Tensor::index_fill(int64_t dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill_int_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::index_fill_(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill__Dimname_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!) +inline at::Tensor & Tensor::index_fill_(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill__Dimname_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::index_fill(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill_Dimname_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor +inline at::Tensor Tensor::index_fill(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill_Dimname_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_src::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter__src::call(const_cast(*this), dim, index, src); +} + +// aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::scatter_value::call(const_cast(*this), dim, index, value); +} + +// aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::scatter__value::call(const_cast(*this), dim, index, value); +} + +// aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const { + return at::_ops::scatter_reduce::call(const_cast(*this), dim, index, src, reduce); +} + +// aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!) 
+inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const { + return at::_ops::scatter__reduce::call(const_cast(*this), dim, index, src, reduce); +} + +// aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const { + return at::_ops::scatter_value_reduce::call(const_cast(*this), dim, index, value, reduce); +} + +// aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const { + return at::_ops::scatter__value_reduce::call(const_cast(*this), dim, index, value, reduce); +} + +// aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_dimname_src::call(const_cast(*this), dim, index, src); +} + +// aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::scatter(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::scatter_dimname_value::call(const_cast(*this), dim, index, value); +} + +// aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter_add(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_add::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_add_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_add_::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_add_dimname::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor +inline at::Tensor Tensor::scatter_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) const { + return at::_ops::scatter_reduce_two::call(const_cast(*this), dim, index, src, reduce, include_self); +} + +// aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) const { + return at::_ops::scatter_reduce__two::call(const_cast(*this), dim, index, src, reduce, include_self); +} + +// aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::eq_(const at::Scalar & other) const { + return at::_ops::eq__Scalar::call(const_cast(*this), other); +} + +// aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::eq_(const at::Tensor & other) const {
+    return at::_ops::eq__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::bitwise_and(const at::Scalar & other) const {
+    return at::_ops::bitwise_and_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::bitwise_and(const at::Tensor & other) const {
+    return at::_ops::bitwise_and_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_and_(const at::Scalar & other) const {
+    return at::_ops::bitwise_and__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_and_(const at::Tensor & other) const {
+    return at::_ops::bitwise_and__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::__and__(const at::Scalar & other) const {
+    return at::_ops::__and___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::__and__(const at::Tensor & other) const {
+    return at::_ops::__and___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::__iand__(const at::Scalar & other) const {
+    return at::_ops::__iand___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::__iand__(const at::Tensor & other) const {
+    return at::_ops::__iand___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::bitwise_or(const at::Scalar & other) const {
+    return at::_ops::bitwise_or_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::bitwise_or(const at::Tensor & other) const {
+    return at::_ops::bitwise_or_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_or_(const at::Scalar & other) const {
+    return at::_ops::bitwise_or__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_or_(const at::Tensor & other) const {
+    return at::_ops::bitwise_or__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::__or__(const at::Scalar & other) const {
+    return at::_ops::__or___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::__or__(const at::Tensor & other) const {
+    return at::_ops::__or___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ior__(const at::Scalar & other) const {
+    return at::_ops::__ior___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ior__(const at::Tensor & other) const { + return at::_ops::__ior___Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_xor(const at::Scalar & other) const { + return at::_ops::bitwise_xor_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_xor(const at::Tensor & other) const { + return at::_ops::bitwise_xor_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_xor_(const at::Scalar & other) const { + return at::_ops::bitwise_xor__Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_xor_(const at::Tensor & other) const { + return at::_ops::bitwise_xor__Tensor::call(const_cast(*this), other); +} + +// aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__xor__(const at::Scalar & other) const { + return at::_ops::__xor___Scalar::call(const_cast(*this), other); +} + +// aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__xor__(const at::Tensor & other) const { + return at::_ops::__xor___Tensor::call(const_cast(*this), other); +} + +// aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__ixor__(const at::Scalar & other) const { + return at::_ops::__ixor___Scalar::call(const_cast(*this), other); +} + +// aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::__ixor__(const at::Tensor & other) const { + return at::_ops::__ixor___Tensor::call(const_cast(*this), other); +} + +// aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__lshift__(const at::Scalar & other) const { + return at::_ops::__lshift___Scalar::call(const_cast(*this), other); +} + +// aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__lshift__(const at::Tensor & other) const { + return at::_ops::__lshift___Tensor::call(const_cast(*this), other); +} + +// aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__ilshift__(const at::Scalar & other) const { + return at::_ops::__ilshift___Scalar::call(const_cast(*this), other); +} + +// aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::__ilshift__(const at::Tensor & other) const { + return at::_ops::__ilshift___Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_left_shift(const at::Tensor & other) const { + return at::_ops::bitwise_left_shift_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_left_shift_(const at::Tensor & other) const { + return at::_ops::bitwise_left_shift__Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_left_shift(const at::Scalar & other) const { + return at::_ops::bitwise_left_shift_Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_left_shift_(const at::Scalar & other) const { + return at::_ops::bitwise_left_shift__Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__rshift__(const at::Scalar & other) const { + return at::_ops::__rshift___Scalar::call(const_cast(*this), other); +} + +// aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__rshift__(const at::Tensor & other) const { + return at::_ops::__rshift___Tensor::call(const_cast(*this), other); +} + +// aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__irshift__(const at::Scalar & other) const { + return at::_ops::__irshift___Scalar::call(const_cast(*this), other); +} + +// aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::__irshift__(const at::Tensor & other) const { + return at::_ops::__irshift___Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_right_shift(const at::Tensor & other) const { + return at::_ops::bitwise_right_shift_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_right_shift_(const at::Tensor & other) const { + return at::_ops::bitwise_right_shift__Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_right_shift(const at::Scalar & other) const { + return at::_ops::bitwise_right_shift_Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_right_shift_(const at::Scalar & other) const { + return at::_ops::bitwise_right_shift__Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) +inline at::Tensor & Tensor::tril_(int64_t diagonal) const { + return at::_ops::tril_::call(const_cast(*this), diagonal); +} + +// aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) +inline at::Tensor & Tensor::triu_(int64_t diagonal) const { + return at::_ops::triu_::call(const_cast(*this), diagonal); +} + +// aten::digamma_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::digamma_() const { + return at::_ops::digamma_::call(const_cast(*this)); +} + +// aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) +inline at::Tensor & Tensor::lerp_(const at::Tensor & end, const at::Scalar & weight) const { + return at::_ops::lerp__Scalar::call(const_cast(*this), end, weight); +} + +// aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) +inline at::Tensor & Tensor::lerp_(const at::Tensor & end, const at::Tensor & weight) const { + return at::_ops::lerp__Tensor::call(const_cast(*this), end, weight); +} + +// aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) 
+inline at::Tensor & Tensor::addbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addbmm_::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::addbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addbmm::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::random_(int64_t from, ::std::optional to, ::std::optional generator) const { + return at::_ops::random__from::call(const_cast(*this), from, to, generator); +} + +// aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::random_(int64_t to, ::std::optional generator) const { + return at::_ops::random__to::call(const_cast(*this), to, generator); +} + +// aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::random_(::std::optional generator) const { + return at::_ops::random_::call(const_cast(*this), generator); +} + +// aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::uniform_(double from, double to, ::std::optional generator) const { + return at::_ops::uniform_::call(const_cast(*this), from, to, generator); +} + +// aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::cauchy_(double median, double sigma, ::std::optional generator) const { + return at::_ops::cauchy_::call(const_cast(*this), median, sigma, generator); +} + +// aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::log_normal_(double mean, double std, ::std::optional generator) const { + return at::_ops::log_normal_::call(const_cast(*this), mean, std, generator); +} + +// aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::exponential_(double lambd, ::std::optional generator) const { + return at::_ops::exponential_::call(const_cast(*this), lambd, generator); +} + +// aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::geometric_(double p, ::std::optional generator) const { + return at::_ops::geometric_::call(const_cast(*this), p, generator); +} + +// aten::diag(Tensor self, int diagonal=0) -> Tensor +inline at::Tensor Tensor::diag(int64_t diagonal) const { + return at::_ops::diag::call(const_cast(*this), diagonal); +} + +// aten::cross(Tensor self, Tensor other, int? 
dim=None) -> Tensor +inline at::Tensor Tensor::cross(const at::Tensor & other, ::std::optional dim) const { + return at::_ops::cross::call(const_cast(*this), other, dim); +} + +// aten::triu(Tensor self, int diagonal=0) -> Tensor +inline at::Tensor Tensor::triu(int64_t diagonal) const { + return at::_ops::triu::call(const_cast(*this), diagonal); +} + +// aten::tril(Tensor self, int diagonal=0) -> Tensor +inline at::Tensor Tensor::tril(int64_t diagonal) const { + return at::_ops::tril::call(const_cast(*this), diagonal); +} + +// aten::trace(Tensor self) -> Tensor +inline at::Tensor Tensor::trace() const { + return at::_ops::trace::call(const_cast(*this)); +} + +// aten::ne.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::ne(const at::Scalar & other) const { + return at::_ops::ne_Scalar::call(const_cast(*this), other); +} + +// aten::ne.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::ne(const at::Tensor & other) const { + return at::_ops::ne_Tensor::call(const_cast(*this), other); +} + +// aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::ne_(const at::Scalar & other) const { + return at::_ops::ne__Scalar::call(const_cast(*this), other); +} + +// aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::ne_(const at::Tensor & other) const { + return at::_ops::ne__Tensor::call(const_cast(*this), other); +} + +// aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::not_equal(const at::Scalar & other) const { + return at::_ops::not_equal_Scalar::call(const_cast(*this), other); +} + +// aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::not_equal(const at::Tensor & other) const { + return at::_ops::not_equal_Tensor::call(const_cast(*this), other); +} + +// aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::not_equal_(const at::Scalar & other) const { + return at::_ops::not_equal__Scalar::call(const_cast(*this), other); +} + +// aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::not_equal_(const at::Tensor & other) const { + return at::_ops::not_equal__Tensor::call(const_cast(*this), other); +} + +// aten::eq.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::eq(const at::Scalar & other) const { + return at::_ops::eq_Scalar::call(const_cast(*this), other); +} + +// aten::eq.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::eq(const at::Tensor & other) const { + return at::_ops::eq_Tensor::call(const_cast(*this), other); +} + +// aten::ge.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::ge(const at::Scalar & other) const { + return at::_ops::ge_Scalar::call(const_cast(*this), other); +} + +// aten::ge.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::ge(const at::Tensor & other) const { + return at::_ops::ge_Tensor::call(const_cast(*this), other); +} + +// aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::ge_(const at::Scalar & other) const { + return at::_ops::ge__Scalar::call(const_cast(*this), other); +} + +// aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::ge_(const at::Tensor & other) const { + return at::_ops::ge__Tensor::call(const_cast(*this), other); +} + +// aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::greater_equal(const at::Scalar & other) const { + return at::_ops::greater_equal_Scalar::call(const_cast(*this), other); +} + +// aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::greater_equal(const at::Tensor & other) const { + return at::_ops::greater_equal_Tensor::call(const_cast(*this), other); +} + +// aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_equal_(const at::Scalar & other) const { + return at::_ops::greater_equal__Scalar::call(const_cast(*this), other); +} + +// aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_equal_(const at::Tensor & other) const { + return at::_ops::greater_equal__Tensor::call(const_cast(*this), other); +} + +// aten::le.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::le(const at::Scalar & other) const { + return at::_ops::le_Scalar::call(const_cast(*this), other); +} + +// aten::le.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::le(const at::Tensor & other) const { + return at::_ops::le_Tensor::call(const_cast(*this), other); +} + +// aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::le_(const at::Scalar & other) const { + return at::_ops::le__Scalar::call(const_cast(*this), other); +} + +// aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::le_(const at::Tensor & other) const { + return at::_ops::le__Tensor::call(const_cast(*this), other); +} + +// aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::less_equal(const at::Scalar & other) const { + return at::_ops::less_equal_Scalar::call(const_cast(*this), other); +} + +// aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::less_equal(const at::Tensor & other) const { + return at::_ops::less_equal_Tensor::call(const_cast(*this), other); +} + +// aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::less_equal_(const at::Scalar & other) const { + return at::_ops::less_equal__Scalar::call(const_cast(*this), other); +} + +// aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::less_equal_(const at::Tensor & other) const { + return at::_ops::less_equal__Tensor::call(const_cast(*this), other); +} + +// aten::gt.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::gt(const at::Scalar & other) const { + return at::_ops::gt_Scalar::call(const_cast(*this), other); +} + +// aten::gt.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::gt(const at::Tensor & other) const { + return at::_ops::gt_Tensor::call(const_cast(*this), other); +} + +// aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::gt_(const at::Scalar & other) const { + return at::_ops::gt__Scalar::call(const_cast(*this), other); +} + +// aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::gt_(const at::Tensor & other) const { + return at::_ops::gt__Tensor::call(const_cast(*this), other); +} + +// aten::greater.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::greater(const at::Scalar & other) const { + return at::_ops::greater_Scalar::call(const_cast(*this), other); +} + +// aten::greater.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::greater(const at::Tensor & other) const { + return at::_ops::greater_Tensor::call(const_cast(*this), other); +} + +// aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_(const at::Scalar & other) const { + return at::_ops::greater__Scalar::call(const_cast(*this), other); +} + +// aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_(const at::Tensor & other) const { + return at::_ops::greater__Tensor::call(const_cast(*this), other); +} + +// aten::lt.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::lt(const at::Scalar & other) const { + return at::_ops::lt_Scalar::call(const_cast(*this), other); +} + +// aten::lt.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::lt(const at::Tensor & other) const { + return at::_ops::lt_Tensor::call(const_cast(*this), other); +} + +// aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::lt_(const at::Scalar & other) const { + return at::_ops::lt__Scalar::call(const_cast(*this), other); +} + +// aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::lt_(const at::Tensor & other) const { + return at::_ops::lt__Tensor::call(const_cast(*this), other); +} + +// aten::less.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::less(const at::Scalar & other) const { + return at::_ops::less_Scalar::call(const_cast(*this), other); +} + +// aten::less.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::less(const at::Tensor & other) const { + return at::_ops::less_Tensor::call(const_cast(*this), other); +} + +// aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::less_(const at::Scalar & other) const { + return at::_ops::less__Scalar::call(const_cast(*this), other); +} + +// aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::less_(const at::Tensor & other) const { + return at::_ops::less__Tensor::call(const_cast(*this), other); +} + +// aten::take(Tensor self, Tensor index) -> Tensor +inline at::Tensor Tensor::take(const at::Tensor & index) const { + return at::_ops::take::call(const_cast(*this), index); +} + +// aten::take_along_dim(Tensor self, Tensor indices, int? 
dim=None) -> Tensor +inline at::Tensor Tensor::take_along_dim(const at::Tensor & indices, ::std::optional dim) const { + return at::_ops::take_along_dim::call(const_cast(*this), indices, dim); +} + +// aten::index_select(Tensor self, int dim, Tensor index) -> Tensor +inline at::Tensor Tensor::index_select(int64_t dim, const at::Tensor & index) const { + return at::_ops::index_select::call(const_cast(*this), dim, index); +} + +// aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor +inline at::Tensor Tensor::index_select(at::Dimname dim, const at::Tensor & index) const { + return at::_ops::index_select_dimname::call(const_cast(*this), dim, index); +} + +// aten::masked_select(Tensor self, Tensor mask) -> Tensor +inline at::Tensor Tensor::masked_select(const at::Tensor & mask) const { + return at::_ops::masked_select::call(const_cast(*this), mask); +} + +// aten::nonzero(Tensor self) -> Tensor +inline at::Tensor Tensor::nonzero() const { + return at::_ops::nonzero::call(const_cast(*this)); +} + +// aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor +inline at::Tensor Tensor::nonzero_static(int64_t size, int64_t fill_value) const { + return at::_ops::nonzero_static::call(const_cast(*this), size, fill_value); +} + +// aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor +inline at::Tensor Tensor::nonzero_static_symint(c10::SymInt size, int64_t fill_value) const { + return at::_ops::nonzero_static::call(const_cast(*this), size, fill_value); +} + +// aten::nonzero_numpy(Tensor self) -> Tensor[] +inline ::std::vector Tensor::nonzero_numpy() const { + return at::_ops::nonzero_numpy::call(const_cast(*this)); +} + +// aten::argwhere(Tensor self) -> Tensor +inline at::Tensor Tensor::argwhere() const { + return at::_ops::argwhere::call(const_cast(*this)); +} + +// aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor +inline at::Tensor Tensor::gather(int64_t dim, const at::Tensor & index, bool sparse_grad) const { + return at::_ops::gather::call(const_cast(*this), dim, index, sparse_grad); +} + +// aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor +inline at::Tensor Tensor::gather(at::Dimname dim, const at::Tensor & index, bool sparse_grad) const { + return at::_ops::gather_dimname::call(const_cast(*this), dim, index, sparse_grad); +} + +// aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor Tensor::addcmul(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcmul::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) +inline at::Tensor & Tensor::addcmul_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcmul_::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor Tensor::addcdiv(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcdiv::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) 
+inline at::Tensor & Tensor::addcdiv_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcdiv_::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) +inline ::std::tuple Tensor::triangular_solve(const at::Tensor & A, bool upper, bool transpose, bool unitriangular) const { + return at::_ops::triangular_solve::call(const_cast(*this), A, upper, transpose, unitriangular); +} + +// aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) +inline ::std::tuple Tensor::svd(bool some, bool compute_uv) const { + return at::_ops::svd::call(const_cast(*this), some, compute_uv); +} + +// aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a) +inline at::Tensor Tensor::swapaxes(int64_t axis0, int64_t axis1) const { + return at::_ops::swapaxes::call(const_cast(*this), axis0, axis1); +} + +// aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!) +inline at::Tensor & Tensor::swapaxes_(int64_t axis0, int64_t axis1) const { + return at::_ops::swapaxes_::call(const_cast(*this), axis0, axis1); +} + +// aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a) +inline at::Tensor Tensor::swapdims(int64_t dim0, int64_t dim1) const { + return at::_ops::swapdims::call(const_cast(*this), dim0, dim1); +} + +// aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) +inline at::Tensor & Tensor::swapdims_(int64_t dim0, int64_t dim1) const { + return at::_ops::swapdims_::call(const_cast(*this), dim0, dim1); +} + +// aten::cholesky(Tensor self, bool upper=False) -> Tensor +inline at::Tensor Tensor::cholesky(bool upper) const { + return at::_ops::cholesky::call(const_cast(*this), upper); +} + +// aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor +inline at::Tensor Tensor::cholesky_solve(const at::Tensor & input2, bool upper) const { + return at::_ops::cholesky_solve::call(const_cast(*this), input2, upper); +} + +// aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor +inline at::Tensor Tensor::cholesky_inverse(bool upper) const { + return at::_ops::cholesky_inverse::call(const_cast(*this), upper); +} + +// aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) +inline ::std::tuple Tensor::qr(bool some) const { + return at::_ops::qr::call(const_cast(*this), some); +} + +// aten::geqrf(Tensor self) -> (Tensor a, Tensor tau) +inline ::std::tuple Tensor::geqrf() const { + return at::_ops::geqrf::call(const_cast(*this)); +} + +// aten::orgqr(Tensor self, Tensor input2) -> Tensor +inline at::Tensor Tensor::orgqr(const at::Tensor & input2) const { + return at::_ops::orgqr::call(const_cast(*this), input2); +} + +// aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor +inline at::Tensor Tensor::ormqr(const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose) const { + return at::_ops::ormqr::call(const_cast(*this), input2, input3, left, transpose); +} + +// aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor +inline at::Tensor Tensor::lu_solve(const at::Tensor & LU_data, const at::Tensor & LU_pivots) const { + return at::_ops::lu_solve::call(const_cast(*this), LU_data, LU_pivots); +} + +// aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? 
generator=None) -> Tensor +inline at::Tensor Tensor::multinomial(int64_t num_samples, bool replacement, ::std::optional generator) const { + return at::_ops::multinomial::call(const_cast(*this), num_samples, replacement, generator); +} + +// aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor +inline at::Tensor Tensor::multinomial_symint(c10::SymInt num_samples, bool replacement, ::std::optional generator) const { + return at::_ops::multinomial::call(const_cast(*this), num_samples, replacement, generator); +} + +// aten::lgamma_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::lgamma_() const { + return at::_ops::lgamma_::call(const_cast(*this)); +} + +// aten::lgamma(Tensor self) -> Tensor +inline at::Tensor Tensor::lgamma() const { + return at::_ops::lgamma::call(const_cast(*this)); +} + +// aten::digamma(Tensor self) -> Tensor +inline at::Tensor Tensor::digamma() const { + return at::_ops::digamma::call(const_cast(*this)); +} + +// aten::polygamma(int n, Tensor self) -> Tensor +inline at::Tensor Tensor::polygamma(int64_t n) const { + return at::_ops::polygamma::call(n, const_cast(*this)); +} + +// aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!) +inline at::Tensor & Tensor::polygamma_(int64_t n) const { + return at::_ops::polygamma_::call(const_cast(*this), n); +} + +// aten::erfinv(Tensor self) -> Tensor +inline at::Tensor Tensor::erfinv() const { + return at::_ops::erfinv::call(const_cast(*this)); +} + +// aten::erfinv_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::erfinv_() const { + return at::_ops::erfinv_::call(const_cast(*this)); +} + +// aten::i0(Tensor self) -> Tensor +inline at::Tensor Tensor::i0() const { + return at::_ops::i0::call(const_cast(*this)); +} + +// aten::i0_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::i0_() const { + return at::_ops::i0_::call(const_cast(*this)); +} + +// aten::sign(Tensor self) -> Tensor +inline at::Tensor Tensor::sign() const { + return at::_ops::sign::call(const_cast(*this)); +} + +// aten::sign_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sign_() const { + return at::_ops::sign_::call(const_cast(*this)); +} + +// aten::signbit(Tensor self) -> Tensor +inline at::Tensor Tensor::signbit() const { + return at::_ops::signbit::call(const_cast(*this)); +} + +// aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor +inline at::Tensor Tensor::dist(const at::Tensor & other, const at::Scalar & p) const { + return at::_ops::dist::call(const_cast(*this), other, p); +} + +// aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::atan2_(const at::Tensor & other) const { + return at::_ops::atan2_::call(const_cast(*this), other); +} + +// aten::atan2(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::atan2(const at::Tensor & other) const { + return at::_ops::atan2::call(const_cast(*this), other); +} + +// aten::arctan2(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::arctan2(const at::Tensor & other) const { + return at::_ops::arctan2::call(const_cast(*this), other); +} + +// aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::arctan2_(const at::Tensor & other) const { + return at::_ops::arctan2_::call(const_cast(*this), other); +} + +// aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor +inline at::Tensor Tensor::lerp(const at::Tensor & end, const at::Scalar & weight) const { + return at::_ops::lerp_Scalar::call(const_cast(*this), end, weight); +} + +// aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor +inline at::Tensor Tensor::lerp(const at::Tensor & end, const at::Tensor & weight) const { + return at::_ops::lerp_Tensor::call(const_cast(*this), end, weight); +} + +// aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor +inline at::Tensor Tensor::histc(int64_t bins, const at::Scalar & min, const at::Scalar & max) const { + return at::_ops::histc::call(const_cast(*this), bins, min, max); +} + +// aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) +inline ::std::tuple Tensor::histogram(const at::Tensor & bins, const ::std::optional & weight, bool density) const { + return at::_ops::histogram_bins_tensor::call(const_cast(*this), bins, weight, density); +} + +// aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) +inline ::std::tuple Tensor::histogram(int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density) const { + return at::_ops::histogram_bin_ct::call(const_cast(*this), bins, range, weight, density); +} + +// aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::fmod(const at::Scalar & other) const { + return at::_ops::fmod_Scalar::call(const_cast(*this), other); +} + +// aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::fmod_(const at::Scalar & other) const { + return at::_ops::fmod__Scalar::call(const_cast(*this), other); +} + +// aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::fmod(const at::Tensor & other) const { + return at::_ops::fmod_Tensor::call(const_cast(*this), other); +} + +// aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::fmod_(const at::Tensor & other) const { + return at::_ops::fmod__Tensor::call(const_cast(*this), other); +} + +// aten::hypot(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::hypot(const at::Tensor & other) const { + return at::_ops::hypot::call(const_cast(*this), other); +} + +// aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::hypot_(const at::Tensor & other) const { + return at::_ops::hypot_::call(const_cast(*this), other); +} + +// aten::igamma(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::igamma(const at::Tensor & other) const { + return at::_ops::igamma::call(const_cast(*this), other); +} + +// aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::igamma_(const at::Tensor & other) const { + return at::_ops::igamma_::call(const_cast(*this), other); +} + +// aten::igammac(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::igammac(const at::Tensor & other) const { + return at::_ops::igammac::call(const_cast(*this), other); +} + +// aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::igammac_(const at::Tensor & other) const { + return at::_ops::igammac_::call(const_cast(*this), other); +} + +// aten::nextafter(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::nextafter(const at::Tensor & other) const { + return at::_ops::nextafter::call(const_cast(*this), other); +} + +// aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::nextafter_(const at::Tensor & other) const { + return at::_ops::nextafter_::call(const_cast(*this), other); +} + +// aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::remainder(const at::Scalar & other) const { + return at::_ops::remainder_Scalar::call(const_cast(*this), other); +} + +// aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::remainder_(const at::Scalar & other) const { + return at::_ops::remainder__Scalar::call(const_cast(*this), other); +} + +// aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::remainder(const at::Tensor & other) const { + return at::_ops::remainder_Tensor::call(const_cast(*this), other); +} + +// aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::remainder_(const at::Tensor & other) const { + return at::_ops::remainder__Tensor::call(const_cast(*this), other); +} + +// aten::min(Tensor self) -> Tensor +inline at::Tensor Tensor::min() const { + return at::_ops::min::call(const_cast(*this)); +} + +// aten::fmin(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::fmin(const at::Tensor & other) const { + return at::_ops::fmin::call(const_cast(*this), other); +} + +// aten::max(Tensor self) -> Tensor +inline at::Tensor Tensor::max() const { + return at::_ops::max::call(const_cast(*this)); +} + +// aten::fmax(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::fmax(const at::Tensor & other) const { + return at::_ops::fmax::call(const_cast(*this), other); +} + +// aten::maximum(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::maximum(const at::Tensor & other) const { + return at::_ops::maximum::call(const_cast(*this), other); +} + +// aten::max.other(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::max(const at::Tensor & other) const { + return at::_ops::max_other::call(const_cast(*this), other); +} + +// aten::minimum(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::minimum(const at::Tensor & other) const { + return at::_ops::minimum::call(const_cast(*this), other); +} + +// aten::min.other(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::min(const at::Tensor & other) const { + return at::_ops::min_other::call(const_cast(*this), other); +} + +// aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::quantile(const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::quantile::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::quantile(double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::quantile_scalar::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::nanquantile(Tensor self, Tensor q, int? 
dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::nanquantile(const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::nanquantile::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::nanquantile(double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::nanquantile_scalar::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(int64_t dim, bool descending) const { + return at::_ops::sort::call(const_cast(*this), dim, descending); +} + +// aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(::std::optional stable, int64_t dim, bool descending) const { + return at::_ops::sort_stable::call(const_cast(*this), stable, dim, descending); +} + +// aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(at::Dimname dim, bool descending) const { + return at::_ops::sort_dimname::call(const_cast(*this), dim, descending); +} + +// aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(::std::optional stable, at::Dimname dim, bool descending) const { + return at::_ops::sort_dimname_stable::call(const_cast(*this), stable, dim, descending); +} + +// aten::msort(Tensor self) -> Tensor +inline at::Tensor Tensor::msort() const { + return at::_ops::msort::call(const_cast(*this)); +} + +// aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor +inline at::Tensor Tensor::argsort(int64_t dim, bool descending) const { + return at::_ops::argsort::call(const_cast(*this), dim, descending); +} + +// aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor +inline at::Tensor Tensor::argsort(bool stable, int64_t dim, bool descending) const { + return at::_ops::argsort_stable::call(const_cast(*this), stable, dim, descending); +} + +// aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor +inline at::Tensor Tensor::argsort(at::Dimname dim, bool descending) const { + return at::_ops::argsort_dimname::call(const_cast(*this), dim, descending); +} + +// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const { + return at::_ops::topk::call(const_cast(*this), k, dim, largest, sorted); +} + +// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::topk_symint(c10::SymInt k, int64_t dim, bool largest, bool sorted) const { + return at::_ops::topk::call(const_cast(*this), k, dim, largest, sorted); +} + +// aten::all(Tensor self) -> Tensor +inline at::Tensor Tensor::all() const { + return at::_ops::all::call(const_cast(*this)); +} + +// aten::any(Tensor self) -> Tensor +inline at::Tensor Tensor::any() const { + return at::_ops::any::call(const_cast(*this)); +} + +// 
aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor +inline at::Tensor Tensor::renorm(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const { + return at::_ops::renorm::call(const_cast(*this), p, dim, maxnorm); +} + +// aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!) +inline at::Tensor & Tensor::renorm_(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const { + return at::_ops::renorm_::call(const_cast(*this), p, dim, maxnorm); +} + +// aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) +inline at::Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) const { + return at::_ops::unfold::call(const_cast(*this), dimension, size, step); +} + +// aten::equal(Tensor self, Tensor other) -> bool +inline bool Tensor::equal(const at::Tensor & other) const { + return at::_ops::equal::call(const_cast(*this), other); +} + +// aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor +inline at::Tensor Tensor::pow(const at::Tensor & exponent) const { + return at::_ops::pow_Tensor_Tensor::call(const_cast(*this), exponent); +} + +// aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor +inline at::Tensor Tensor::pow(const at::Scalar & exponent) const { + return at::_ops::pow_Tensor_Scalar::call(const_cast(*this), exponent); +} + +// aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) +inline at::Tensor & Tensor::pow_(const at::Scalar & exponent) const { + return at::_ops::pow__Scalar::call(const_cast(*this), exponent); +} + +// aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) +inline at::Tensor & Tensor::pow_(const at::Tensor & exponent) const { + return at::_ops::pow__Tensor::call(const_cast(*this), exponent); +} + +// aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor +inline at::Tensor Tensor::float_power(const at::Tensor & exponent) const { + return at::_ops::float_power_Tensor_Tensor::call(const_cast(*this), exponent); +} + +// aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor +inline at::Tensor Tensor::float_power(const at::Scalar & exponent) const { + return at::_ops::float_power_Tensor_Scalar::call(const_cast(*this), exponent); +} + +// aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) +inline at::Tensor & Tensor::float_power_(const at::Scalar & exponent) const { + return at::_ops::float_power__Scalar::call(const_cast(*this), exponent); +} + +// aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) +inline at::Tensor & Tensor::float_power_(const at::Tensor & exponent) const { + return at::_ops::float_power__Tensor::call(const_cast(*this), exponent); +} + +// aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::normal_(double mean, double std, ::std::optional generator) const { + return at::_ops::normal_::call(const_cast(*this), mean, std, generator); +} + +// aten::alias(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::alias() const { + return at::_ops::alias::call(const_cast(*this)); +} + +// aten::isfinite(Tensor self) -> Tensor +inline at::Tensor Tensor::isfinite() const { + return at::_ops::isfinite::call(const_cast(*this)); +} + +// aten::isinf(Tensor self) -> Tensor +inline at::Tensor Tensor::isinf() const { + return at::_ops::isinf::call(const_cast(*this)); +} + +// aten::record_stream(Tensor(a!) 
self, Stream s) -> () +inline void Tensor::record_stream(at::Stream s) const { + return at::_ops::record_stream::call(const_cast(*this), s); +} + +// aten::isposinf(Tensor self) -> Tensor +inline at::Tensor Tensor::isposinf() const { + return at::_ops::isposinf::call(const_cast(*this)); +} + +// aten::isneginf(Tensor self) -> Tensor +inline at::Tensor Tensor::isneginf() const { + return at::_ops::isneginf::call(const_cast(*this)); +} + +// aten::det(Tensor self) -> Tensor +inline at::Tensor Tensor::det() const { + return at::_ops::det::call(const_cast(*this)); +} + +// aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) +inline ::std::tuple Tensor::slogdet() const { + return at::_ops::slogdet::call(const_cast(*this)); +} + +// aten::logdet(Tensor self) -> Tensor +inline at::Tensor Tensor::logdet() const { + return at::_ops::logdet::call(const_cast(*this)); +} + +// aten::inverse(Tensor self) -> Tensor +inline at::Tensor Tensor::inverse() const { + return at::_ops::inverse::call(const_cast(*this)); +} + +// aten::inner(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::inner(const at::Tensor & other) const { + return at::_ops::inner::call(const_cast(*this), other); +} + +// aten::outer(Tensor self, Tensor vec2) -> Tensor +inline at::Tensor Tensor::outer(const at::Tensor & vec2) const { + return at::_ops::outer::call(const_cast(*this), vec2); +} + +// aten::ger(Tensor self, Tensor vec2) -> Tensor +inline at::Tensor Tensor::ger(const at::Tensor & vec2) const { + return at::_ops::ger::call(const_cast(*this), vec2); +} + +// aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor +inline at::Tensor Tensor::to_padded_tensor(double padding, at::OptionalIntArrayRef output_size) const { + return at::_ops::to_padded_tensor::call(const_cast(*this), padding, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt); +} + +// aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor +inline at::Tensor Tensor::to_padded_tensor_symint(double padding, at::OptionalSymIntArrayRef output_size) const { + return at::_ops::to_padded_tensor::call(const_cast(*this), padding, output_size); +} +} // namespace at + + +namespace c10 { +template <> +struct MaybeOwnedTraits { + using owned_type = at::Tensor; + using borrow_type = at::Tensor; + + static borrow_type createBorrow(const owned_type& from) { + // NOTE: this can be implemented without the special + // unsafe_borrow_t Tensor constructor as + // + // return borrow_type(c10::intrusive_ptr::reclaim(from.unsafeGetTensorImpl())); + // + // but that hurts inlining due to the nullptr check in the + // Tensor(c10::intrusive_ptr<...>) constructor. We already know + // that from.impl_ isn't null because from is a valid Tensor, so + // we needn't do the check again. (using __builtin_assume can + // avoid this, but wouldn't be portable to MSVC.) + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseTensorImpl(); + // See above note: this can be implemented with public API + // similarly to createBorrow(), but that would hurt inlining. + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0. 
+ } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits { + using repr_type = at::Tensor; + using pointer_type = at::Tensor*; + using const_pointer_type = const at::Tensor*; + + static repr_type nullRepr() { + return at::Tensor(); + } + + template + static repr_type createInPlace(Args&&... args) { + return at::Tensor(std::forward(args)...); + } + + static repr_type moveToRepr(at::Tensor&& x) { + return std::move(x); + } + + static void destroyOwned(at::Tensor& x) { + return ExclusivelyOwnedTraits::destroyOwned(x); + } + + static at::Tensor take(at::Tensor& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 + +namespace at { + +inline c10::MaybeOwned borrow_from_optional_tensor( + const std::optional& opt) { + return opt.has_value() + ? c10::MaybeOwned::borrowed(*opt) + : c10::MaybeOwned::owned(std::in_place); +} + +inline c10::MaybeOwned Tensor::expect_contiguous(MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); + } +} +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..0e119863eca135f826bfb312436300aa731efcb4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::impl { + +TORCH_API bool tensor_has_dispatch(const at::Tensor& t); +TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li); +TORCH_API bool tensorlist_has_dispatch( + const c10::List>& li); +using c10::impl::dispatch_mode_enabled; + +} // namespace at::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h b/phivenv/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..ea0ba9723e7174ce237d044e3f4564cfe023179e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h @@ -0,0 +1,175 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// Using DistAccumType in accumulate types for distributions. +// Note: Ideally we'd be using ATen/AccumulateType.h but looks +// like the there is some inconsistency in how accumulate types +// are mapped currently, e.g. for the cpu side, float is mapped +// to double. 
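// Usage sketch (illustrative, standalone snippet; add_optional_bias is a
// hypothetical helper): borrow_from_optional_tensor is the idiomatic way for
// a kernel to consume an optional Tensor argument without an extra refcount
// bump. It borrows when the optional is engaged and otherwise owns a
// default-constructed (undefined) Tensor.
#include <ATen/core/Tensor.h>
#include <optional>

at::Tensor add_optional_bias(const at::Tensor& x,
                             const std::optional<at::Tensor>& bias_opt) {
  c10::MaybeOwned<at::Tensor> bias = at::borrow_from_optional_tensor(bias_opt);
  // operator* / operator-> yield a const Tensor& regardless of whether the
  // MaybeOwned is in the borrowed or the owned state.
  return bias->defined() ? x.add(*bias) : x;
}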
+template +struct DistAccumType { }; + +#if defined(__CUDACC__) || defined(__HIPCC__) +template <> struct DistAccumType { using type = float; }; +#endif +template <> struct DistAccumType { using type = float; }; +template <> struct DistAccumType { using type = float; }; +template <> struct DistAccumType { using type = float; }; +template <> struct DistAccumType { using type = double; }; + +template +using dist_acctype = typename DistAccumType::type; + +namespace transformation { + +/** + * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified. + * `range` is `to - from` + * `base` is `from` + */ +template +C10_HOST_DEVICE inline T uniform_int_from_to(V val, uint64_t range, int64_t base) { + return static_cast(static_cast((val % range) + base)); +} + +/** + * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and to=None + */ +template +C10_HOST_DEVICE inline T uniform_int_full_range(V val) { + return static_cast(static_cast(val)); +} + +/** + * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`. + * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double + * in this overloaded version + */ +template +C10_HOST_DEVICE inline std::enable_if_t), T>uniform_int(V val) { + if constexpr (std::is_same_v) { + return static_cast(val & 1); + } else if constexpr (std::is_same_v) { + return static_cast(val % (static_cast(std::numeric_limits::max()) + 1)); + } else if constexpr (std::is_same_v || std::is_same_v) { + return static_cast(val % static_cast((1ULL << std::numeric_limits::digits) + 1)); + } else if constexpr (std::is_integral_v) { + return static_cast(val % (static_cast(std::numeric_limits::max()) + 1)); + } else { + assert(false); + return 0; + } +} + +/** + * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`, + * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version. + */ +template +C10_HOST_DEVICE inline std::enable_if_t, T>uniform_int(V val) { + return static_cast(val % static_cast((1ULL << std::numeric_limits::digits) + 1)); +} + +template +C10_HOST_DEVICE inline dist_acctype uniform_real(V val, T from, T to) { + constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); + constexpr auto DIVISOR = static_cast>(1) / (static_cast(1) << std::numeric_limits::digits); + dist_acctype x = (val & MASK) * DIVISOR; + return (x * (to - from) + from); +} + +/** + * Transforms normally distributed `val` with mean 0.0 and standard deviation 1.0 to + * normally distributed with `mean` and standard deviation `std`. + */ +template +C10_HOST_DEVICE inline T normal(T val, T mean, T std) { + return val * std + mean; +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * Cauchy distribution with location parameter `median` and scale parameter `sigma`. + */ +template +C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) { + // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function + // __tanf overflows and returns `inf/-inf` when (val > 1 - eps) or (val < 0 + eps), + // thus we clip those values. + constexpr T eps = std::numeric_limits::epsilon(); + constexpr T one_minus_eps = 1 - eps; + constexpr T zero_plus_eps = 0 + eps; + val = (val > one_minus_eps ? one_minus_eps : val); + val = (val < zero_plus_eps ? 
zero_plus_eps : val); + return median + sigma * at::tan(c10::pi * (val - static_cast(0.5))); +} + +template <> +C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) { + // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function + return median + sigma * at::tan(c10::pi * (val - static_cast(0.5))); +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * exponentially distributed with `lambda` parameter of the distribution. + */ +template +C10_HOST_DEVICE inline T exponential(T val, T lambda) { + // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates + // Different implementations for CUDA and CPU to preserve original logic + // TODO: must be investigated and unified!!! + // https://github.com/pytorch/pytorch/issues/38662 +#if defined(__CUDACC__) || defined(__HIPCC__) + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706 + // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0. + // we need log to be not 0, and not underflow when converted to half + // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args + auto log = val >= static_cast(1.) - std::numeric_limits::epsilon() / 2 + ? -std::numeric_limits::epsilon() / 2 + : at::log(val); + return static_cast(-1.0) / lambda * log; +#else + return static_cast(-1.0) / lambda * at::log1p(-val); +#endif +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * geometrically distributed with success probability `p`. + */ +template +C10_HOST_DEVICE inline T geometric(T val, T p) { + // https://en.wikipedia.org/wiki/Geometric_distribution#Related_distributions + return static_cast(::ceil(at::log(val) / at::log1p(-p))); +} + +/** + * Transforms normally distributed `val` to log-normally distributed. + */ +template +C10_HOST_DEVICE inline T log_normal(T val) { + // https://en.wikipedia.org/wiki/Log-normal_distribution#Mode,_median,_quantiles + return at::exp(val); +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * bernoulli distributed with success probability `p`. 
+ */ +template +C10_HOST_DEVICE inline T bernoulli(T val, T p) { + return val < p; +} + +}} // namespace at::transformation diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..49612392cc4f66224d30e9480522acca886fd293 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h @@ -0,0 +1 @@ +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h b/phivenv/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h new file mode 100644 index 0000000000000000000000000000000000000000..a47ad1586d70587faf7dea99d50b20dbea3a344f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h @@ -0,0 +1,21 @@ +#pragma once +#include + +namespace at { + +inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) { + auto tensor_impl = c10::intrusive_ptr::reclaim(static_cast(th_pointer)); + if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) { + c10::raw::intrusive_ptr::incref(tensor_impl.get()); + } + return Tensor(std::move(tensor_impl)); +} + +inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) { + if (retain && th_pointer) { + c10::raw::intrusive_ptr::incref(static_cast(th_pointer)); + } + return Storage(c10::intrusive_ptr::reclaim(static_cast(th_pointer))); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..0d46d7beebb6bd94f0328bfaf9030b27846f7f8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include + +// A little explanation about why this file exists at all. We have +// a few methods on Tensor class which require access to reified access to +// AutogradMeta. In open source, this isn't a big deal: we just access +// torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and +// we can put the definitions inline. This is because everything gets balled +// into a single dynamic library in the end. +// +// However, inside our Facebook internal version of our build system, we +// have a split between aten and torch/csrc. So we cannot simply just +// cross this boundary. "Now wait," you might say, "Why don't we just +// merge the libraries inside Facebook". Well, the problem is that there +// are some downstream applications which are at binary size limit, and +// incorporating all of the extra code from libtorch would push them +// over (admarket/adreview/service:adreviewservice, see also +// https://github.com/pytorch/pytorch/pull/29299) So if you want to do that, +// we have to fix all of the services like this. +// +// I didn't want to block eliminating Tensor-Variable on this work, so I +// had to introduce another dynamic dispatch to get to the variable +// implementations (which live in torch/csrc/autograd/variable.cpp, FYI). +// +// I also considered using our existing dynamic dispatch mechanism, c10 +// dispatcher, to do this. 
However, (1) some of the functions on Tensor +// have weird signatures that are not supported by autograd, and (2) +// see this bug https://github.com/pytorch/pytorch/issues/30102 + +namespace torch::autograd { + +struct Node; + +} // namespace torch::autograd + +namespace at::impl { + +struct TORCH_API VariableHooksInterface { + virtual ~VariableHooksInterface() = default; + virtual TensorBase tensor_data(const TensorBase&) const = 0; + virtual TensorBase variable_data(const TensorBase&) const = 0; + virtual const std::shared_ptr& grad_fn( + const TensorBase&) const = 0; + virtual unsigned _register_hook( + const TensorBase&, + std::function hook) const = 0; + virtual void remove_hook(const TensorBase&, unsigned pos) const = 0; + virtual bool is_view(const TensorBase&) const = 0; + virtual const TensorBase& base(const TensorBase&) const = 0; + virtual const std::string& name(const TensorBase&) const = 0; + virtual bool is_leaf(const TensorBase&) const = 0; + virtual int64_t output_nr(const TensorBase&) const = 0; + virtual void set_data(const TensorBase&, const TensorBase&) const = 0; + virtual TensorBase data(const TensorBase&) const = 0; + virtual int64_t _version(const TensorBase&) const = 0; + virtual void retain_grad(const TensorBase&) const = 0; + virtual bool retains_grad(const TensorBase&) const = 0; + virtual void _backward( + const Tensor&, + TensorList, + const std::optional&, + std::optional, + bool) const = 0; + virtual void requires_grad_(const TensorBase&, bool) const = 0; + virtual void basic_autograd_not_implemented_fallback( + const c10::OperatorHandle& op, + c10::DispatchKeySet dispatch_keys, + torch::jit::Stack* stack) const = 0; +}; + +TORCH_API void SetVariableHooks(VariableHooksInterface* hooks); +TORCH_API VariableHooksInterface* GetVariableHooks(); +TORCH_API bool HasVariableHooks(); + +struct TORCH_API VariableHooksRegisterer { + explicit VariableHooksRegisterer(VariableHooksInterface* hooks) { + SetVariableHooks(hooks); + } +}; + +} // namespace at::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Variadic.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Variadic.h new file mode 100644 index 0000000000000000000000000000000000000000..3f850501e340d8fc1489b6af12562fc971040a37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Variadic.h @@ -0,0 +1,92 @@ +#pragma once + +#include + +#include +#include + +namespace at { + +// This class allows you to write variadic functions which +// call a (possibly overloaded) function on each argument, +// in order. This is most commonly used in autogenerated code, +// where it is convenient to have a function that can uniformly +// take arguments of different types. If your arguments +// are homogenous consider using a std::initializer_list instead. +// +// For examples of this in use, see torch/csrc/utils/variadic.h +template +struct IterArgs { + template + inline F& apply() { + return self(); + } + + // NB: Use perfect forwarding here, otherwise we'll make value + // copies of all arguments! + template + inline F& apply(T&& arg, Args&&... args) { + self()(std::forward(arg)); + if (self().short_circuit()) { + return self(); + } else { + return apply(std::forward(args)...); + } + } + + // Here are some handy overloads which provide sensible + // defaults for container-like structures that one might + // be interested in recursing into. You can enable them + // by adding: + // + // using IterArgs::operator() + // + // to your struct. 
These are not enabled by default because + // you may be able to process these structures more efficiently + // than handling them one-by-one. + + template + void operator()(c10::IListRef args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + + template + void operator()(at::ArrayRef args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + + template + void operator()(const torch::List& args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + + // NB: we need to specify std::vector manually as C++ won't + // do an implicit conversion to make a template deduction go through. + template + void operator()(const std::vector& args) { + self()(at::ArrayRef{args}); + } + + constexpr bool short_circuit() const { + return false; + } + + private: + inline F& self() { + return *static_cast(this); + } +}; + +} // namespace torch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Vitals.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Vitals.h new file mode 100644 index 0000000000000000000000000000000000000000..924005882d6bc83e2b27a4030733c866b92d80ab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Vitals.h @@ -0,0 +1,94 @@ +#pragma once +#include +#include +#include + +#include + +namespace at::vitals { + +TORCH_API bool torchVitalEnabled(); + +struct TORCH_API TorchVitalAttr { + // always initialized to empty + std::string value; + template + TorchVitalAttr& operator<<(const T& t) { + if (torchVitalEnabled()) { + std::stringstream ss; + ss << t; + value += ss.str(); + } + return *this; + } + + template + void write(const T& t, bool force) { + if (force || torchVitalEnabled()) { + std::stringstream ss; + ss << t; + value = ss.str(); + } + } +}; + +struct TORCH_API TorchVital { + std::string name; + std::unordered_map attrs; + + explicit TorchVital(std::string n) : name(std::move(n)) {} + TorchVital(const TorchVital&) = default; + TorchVital(TorchVital&&) = default; + TorchVital& operator=(const TorchVital&) = default; + TorchVital& operator=(TorchVital&&) = default; + TorchVital() = delete; + + TorchVitalAttr& create(const std::string& attr); + TorchVitalAttr& create(const std::string& attr, bool force); + friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt); + + ~TorchVital(); +}; + +std::ostream& operator<<(std::ostream& os, TorchVital const& tv); + +// A way to access vitals by string names instead of by global reference. +// This enables access to vitals from the PythonAPI. +class TORCH_API APIVitals { + public: + bool vitals_enabled; + + // Set any vital sign that was added to the map. 
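// Usage sketch of the IterArgs CRTP helper above (illustrative, standalone
// snippet; CountTensors and count_tensors are hypothetical names): the
// derived visitor is applied to every argument in order, and the
// using-declaration opts in to the container overloads shown above.
#include <ATen/core/Tensor.h>
#include <ATen/core/Variadic.h>
#include <utility>
#include <vector>

struct CountTensors : at::IterArgs<CountTensors> {
  size_t count = 0;
  using at::IterArgs<CountTensors>::operator();  // enable container overloads
  void operator()(const at::Tensor& /*t*/) { ++count; }
  template <typename T>
  void operator()(const T& /*other*/) {}         // ignore non-Tensor arguments
};

template <typename... Args>
size_t count_tensors(Args&&... args) {
  return CountTensors().apply(std::forward<Args>(args)...).count;
}
// e.g. count_tensors(t0, 3, std::vector<at::Tensor>{t1, t2}) evaluates to 3.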
+ bool setVital( + const std::string& vital_name, + const std::string& attr_name, + const std::string& value, + bool force = false); + std::string readVitals(); + + APIVitals(); + + // Ensure this stays a singleton + APIVitals(APIVitals const& other) = delete; + APIVitals(APIVitals&& other) = delete; + APIVitals& operator=(const APIVitals&) = delete; + APIVitals& operator=(APIVitals&&) = delete; + ~APIVitals() = default; + + private: + std::unordered_map name_map_; +}; + +extern TORCH_API APIVitals VitalsAPI; + +} // namespace at::vitals + +#define TORCH_VITAL_DECLARE(name) \ + TORCH_API at::vitals::TorchVital TorchVital_##name; + +#define TORCH_VITAL_DEFINE(name) \ + TORCH_API at::vitals::TorchVital TorchVital_##name(#name); + +#define TORCH_VITAL_BASE(name) TorchVital_##name + +#define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..52bc1ccf37701bc91a28dbe0884bb4fe9f371ddf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +struct IValue; +using Stack = std::vector; + +class OperatorHandle; +class KernelFunction; + +// This kernel implements the behavior of falling through to the next available +// registered dispatch key. The implementation of this function is FAST; it is +// no overhead to fallthrough to the next key. See cpp file for some more +// implementation notes; notably, this does NOT actually go through the +// boxing/unboxing codepath. +TORCH_API void fallthrough_kernel( + OperatorKernel*, + const OperatorHandle&, + DispatchKeySet, + Stack*); + +// Note [Ambiguity in AutogradOther kernel] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This error-reporting kernel is registered to the AutogradOther entry in the +// dispatch table when there is both a CompositeImplicitAutograd kernel and a +// backend kernel for ANY backend that maps to AutogradOther. To see why +// this is necessary in the AutogradOther case, it's helpful to first see +// why everything works out fine for a backend that has a reserved Autograd +// entry (see rule 2.2 in [Note] DispatchTable computation): +// +// CPU AutogradCPU +// reg? registers with... +// ------------------------------------------------- +// y Autograd registration takes precedence +// over CompositeImplicitAutograd. +// This is good, because the CPU specific backend +// implementation is more specialized and typically better; +// if we used the composite, we would bypass it. +// (NB: the Autograd key is guaranteed to exist because +// the autograd codegen requires it!) +// +// n CompositeImplicitAutograd takes precedence. +// This is also good, because the Autograd +// registration (if it exists) would try to redispatch +// to the (non-existent) CPU implementation; by +// using the composite, we ensure the operator +// actually works. +// +// As you can see, when we have a specific Autograd key (AutogradCPU), we can +// decide whether or not to use the CompositeImplicitAutograd kernel or the +// Autograd kernel based on whether or not the backend kernel exists. 
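// Usage sketch of the vitals API declared above (illustrative, standalone
// snippet; EXAMPLE is a hypothetical vital name): a translation unit defines
// a vital with TORCH_VITAL_DEFINE and streams attribute values through the
// TORCH_VITAL macro; writes are recorded only when vitals are enabled.
#include <ATen/core/Vitals.h>

TORCH_VITAL_DEFINE(EXAMPLE);  // defines at::vitals::TorchVital TorchVital_EXAMPLE

void record_example_vital(bool used_fast_path) {
  TORCH_VITAL(EXAMPLE, fast_path) << (used_fast_path ? "true" : "false");
}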
+// +// However, for AutogradOther (which is the catchall autograd kernel for +// everything that doesn't have a specific Autograd key), we can't do this +// trick because there isn't any unique backend to peek at to disambiguate; +// if there are some backends that have implementations they prefer Autograd, +// but unimplemented backends would prefer CompositeImplicitAutograd. Rather +// than arbitrarily pick one or the other, we just register a kernel that raises +// an error and let the user decide how to proceed. +TORCH_API void ambiguous_autogradother_kernel( + OperatorKernel*, + const OperatorHandle&, + DispatchKeySet, + Stack*); + +// Note [named_not_supported_kernel] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This kernel implements reporting an error message saying that named tensor is +// not supported. This kernel doesn't rely on the Stack, and so it is special +// cased in the dispatcher to be triggered before we attempt boxing (so we can +// give a good error message in cases when boxing is not supported). When +// boxing is universally supported this can be removed. +[[noreturn]] TORCH_API void named_not_supported_kernel( + OperatorKernel*, + const OperatorHandle&, + DispatchKeySet, + Stack*); + +/** + * BoxedKernel is similar to a std::function storing a boxed kernel. + */ +class TORCH_API BoxedKernel final { + public: + // This is how boxed kernels are actually stored + // + // Note [Plumbing Keys Through The Dispatcher] + // Benchmarks have shown that it is expensive for the dispatcher to read from + // thread-local storage (TLS) upon every dispatch call into order to compute + // which kernel to dispatch to. + // + // To mitigate this, we've updated the calling convention inside the + // dispatcher to expect every kernel that it stores to have a first argument + // of type DispatchKeySet. + // + // What are the invariants of the DispatchKeySet when it gets passed to a + // kernel? + // - All keys to the left of the current dispatch key have been masked out. + // (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the + // highest bit to be DispatchKey::Tracer) + // - All other keys that dispatcher normally would have computed through TLS + + // global state + op arguments + // are still in the set. + // + // Kernels can then opt into using this keyset to save the dispatcher from + // doing repeated work during redispatches: recalculating the highest-priority + // dispatch key, which involves reading from TLS. Instead, the kernels that + // opt in will calculate an updated DispatchKeySet directly from the old one, + // and pass the updated set directly into the dispatcher upon redispatching. + // + // This is an opt-in mechanism: Kernels can automatically opt in by setting + // the first argument in their signature to be of type DispatchKeySet. See the + // kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for + // examples. + // + // The mechanism for optionally passing that DispatchKeySet into the kernel + // lives in make_boxed_from_unboxed_functor.h. See Note [Plumbing Keys Through + // The Dispatcher 2] for details. 
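// Sketch of the opt-in convention described in the note above (illustrative,
// standalone snippet; my_add_kernel is a hypothetical kernel): a kernel opts
// in to key plumbing simply by declaring DispatchKeySet as its first
// parameter; for kernels that do not, the dispatcher drops the keyset before
// calling them.
#include <ATen/core/Tensor.h>
#include <c10/core/DispatchKeySet.h>

at::Tensor my_add_kernel(c10::DispatchKeySet ks,
                         const at::Tensor& a,
                         const at::Tensor& b) {
  // `ks` already has every key above the current dispatch key masked out, so
  // a redispatch helper can reuse it instead of re-reading thread-local state.
  (void)ks;
  return a.add(b);
}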
+ using InternalBoxedKernelFunction = + void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); + // This is the public API for how boxed kernels are defined + using BoxedKernelFunction = void(const OperatorHandle&, Stack*); + using BoxedKernelFunction_withDispatchKeys = + void(const OperatorHandle&, DispatchKeySet, Stack*); + + BoxedKernel(); + + // Fast path for dispatch to allow not touching the boxed kernel in + // the common case where unboxed is available. + bool isValid() const; + bool isFallthrough() const; + + /** + * Call the function with boxed arguments. + */ + void callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const; + + /** + * Create a KernelFunction from a boxed function. + * + * Example: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>(); + */ + template + static BoxedKernel makeFromFunction(); + + /** + * TODO: This will only be useful if we write a backend fallback that plumbs + * dispatch keys (currently there are none) See Note [Plumbing Keys Through + * The Dispatcher] for details. + */ + template + static BoxedKernel makeFromFunction(); + + /** + * Create a KernelFunction from a boxed functor. + * + * Example: + * + * > class MyFunctor final : public c10::OperatorKernel { + * > public: + * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...} + * > }; + * > BoxedKernel func = + * BoxedKernel::makeFromFunctor(std::make_unique()); + */ + template + static BoxedKernel makeFromFunctor( + std::unique_ptr kernelFunctor); + + static BoxedKernel makeFallthrough(); + static BoxedKernel makeAmbiguousAutogradOther(); + static BoxedKernel makeNamedNotSupported(); + + private: + friend class KernelFunction; + + template + static void make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet, + Stack* stack); + + template + static void make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet, + Stack* stack); + + explicit BoxedKernel( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func); + + OperatorKernel* getFunctor() const; + InternalBoxedKernelFunction* getFnPtr() const; + + c10::intrusive_ptr functor_; + InternalBoxedKernelFunction* boxed_kernel_func_; +}; + +} // namespace c10 + +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..7c5e455ed583c1816234c541de41e9285b3a0e01 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h @@ -0,0 +1,106 @@ +#pragma once + +namespace c10 { + +inline BoxedKernel::BoxedKernel() : functor_(), boxed_kernel_func_(nullptr) {} + +inline BoxedKernel::BoxedKernel( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func) + : functor_(std::move(functor)), boxed_kernel_func_(boxed_kernel_func) {} + +template +inline void BoxedKernel::make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet, + Stack* stack) { + // Note that we're dropping the DispatchKeySet argument. + // See Note [Plumbing Keys Through The Dispatcher 2] for details. 
+ func(opHandle, stack); +} + +template +inline void BoxedKernel::make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet ks, + Stack* stack) { + // See Note [Plumbing Keys Through The Dispatcher 2] for details. + func(opHandle, ks, stack); +} + +inline bool BoxedKernel::isValid() const { + return boxed_kernel_func_ != nullptr; +} + +inline bool BoxedKernel::isFallthrough() const { + return boxed_kernel_func_ == &fallthrough_kernel; +} + +inline void BoxedKernel::callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + boxed_kernel_func_ != nullptr, + "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel."); + (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunction() { + return BoxedKernel( + nullptr, // no functor_ object + &make_boxed_function); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunction() { + return BoxedKernel( + nullptr, // no functor_ object + &make_boxed_function); +} + +inline BoxedKernel BoxedKernel::makeFallthrough() { + return BoxedKernel( + nullptr, // no functor_ object + &fallthrough_kernel); +} + +inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() { + return BoxedKernel( + nullptr, // no functor_ object + &ambiguous_autogradother_kernel); +} + +inline BoxedKernel BoxedKernel::makeNamedNotSupported() { + return BoxedKernel( + nullptr, // no functor_ object + &named_not_supported_kernel); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunctor( + std::unique_ptr kernelFunctor) { + static_assert( + std::is_base_of_v, + "Tried to call BoxedKernel::makeFromFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + return BoxedKernel( + std::move(kernelFunctor), + [](OperatorKernel* kernel, + const OperatorHandle& op, + DispatchKeySet ks, + Stack* stack) { + (*static_cast(kernel))(op, ks, stack); + }); +} + +inline OperatorKernel* BoxedKernel::getFunctor() const { + return functor_.get(); +} +inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const { + return boxed_kernel_func_; +} + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..2fbae99d396601179ed6af88a54e23e731cf7cb4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h @@ -0,0 +1,283 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack + // to the c10 namespace. 
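// Sketch of the boxed calling convention used throughout this file
// (illustrative, standalone snippet; ones_like_boxed is a hypothetical
// kernel): a boxed kernel pops its arguments off the IValue stack and pushes
// its results back, and the function pointer is supplied as a template
// argument so BoxedKernel can store it without indirection.
#include <ATen/ATen.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/stack.h>

void ones_like_boxed(const c10::OperatorHandle& /*op*/, c10::Stack* stack) {
  at::Tensor self = torch::jit::pop(*stack).toTensor();
  torch::jit::push(*stack, at::ones_like(self));
}

c10::BoxedKernel ones_like_kernel =
    c10::BoxedKernel::makeFromFunction<&ones_like_boxed>();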
+ +class OperatorHandle; +struct OperatorKernel; +class KernelFunction; + +template +using has_symint = std::disjunction< + std::is_same, + std::is_same, + std::is_same, + std::is_same, T>>; + +template +struct remove_symint { + using type = T; +}; + +template <> +struct remove_symint { + using type = int64_t; +}; + +template <> +struct remove_symint { + using type = OptionalIntArrayRef; +}; + +template <> +struct remove_symint { + using type = c10::IntArrayRef; +}; + +template <> +struct remove_symint> { + using type = std::optional; +}; + +template +struct maybe_keep_symint final {}; + +template +struct maybe_keep_symint { + using type = T; +}; + +template +struct maybe_keep_symint { + using type = typename remove_symint::type; +}; + +template +using fn_has_symint = typename guts::typelist::true_for_any_type< + has_symint, + typename guts::infer_function_traits::type::parameter_types>; + +template +struct fn_remove_symint; + +template +struct fn_remove_symint { + using type = Ret(typename remove_symint::type...); +}; + +/** + * KernelFunction is similar to std::function but stores a kernel function. + * You can create a KernelFunction from a boxed or unboxed + * function/functor/lambda and call it in a boxed or unboxed way. If the way it + * was created doesn't match the way it was called, it will do boxing or + * unboxing as necessary. + */ +class TORCH_API KernelFunction final { + public: + using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction; + using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction; + using BoxedKernelFunction_withDispatchKeys = + BoxedKernel::BoxedKernelFunction_withDispatchKeys; + + KernelFunction(); + + // Fast path for dispatch to allow not touching the boxed kernel in + // the common case where unboxed is available. + bool isValidUnboxed() const; + bool isValidSymUnboxed() const; + bool isValid() const; + bool isFallthrough() const; + + /** + * Call the function in a boxed way. + * If the kernel function was created with an unboxed function, + * this will call an unboxing wrapper which then calls into that + * unboxed function. + * + * Example: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); + * > Tensor result = func.callBoxed(stack); + * + * Or, with an unboxed implementation: + * + * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( + * > [] (Tensor a, bool b) -> Tensor {...}); + * > Tensor result = func.callBoxed(stack); + */ + void callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const; + + /** + * Call the function in an unboxed way. + * If the kernel function was created with a boxed function, + * this will box all inputs and then call into that boxed function. + * + * Note that this doesn't work for all types yet. + * + * Example: + * + * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( + * > [] (Tensor a, bool b) -> Tensor {...}); + * > Tensor result = func.call(tensor1, true); + * + * Or, with a boxed implementation: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); + * > Tensor result = func.call(tensor1, true); + */ + template + Return call( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) const; + + /** + * Create a KernelFunction from a BoxedKernel. 
+ */ + static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn); + + /** + * Create a KernelFunction from a boxed function. + * + * Example: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > KernelFunction func = + * KernelFunction::makeFromBoxedFunction<&boxed_func>(); + */ + template + static KernelFunction makeFromBoxedFunction(); + + /** + * TODO: This will only be useful if we write a backend fallback that plumbs + * dispatch keys (currently there are none) See Note [Plumbing Keys Through + * The Dispatcher] for details. + */ + template + static KernelFunction makeFromBoxedFunction(); + + /** + * Create a KernelFunction from an unboxed functor. + * + * Example: + * + * > class MyFunctor final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > KernelFunction func = + * KernelFunction::makeFromUnboxedFunctor(std::make_unique()); + */ + template + static KernelFunction makeFromUnboxedFunctor( + std::unique_ptr kernelFunctor); + + /** + * Create a KernelFunction from a boxed functor. + * + * Example: + * + * > class MyFunctor final : public c10::OperatorKernel { + * > public: + * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...} + * > }; + * > KernelFunction func = + * KernelFunction::makeFromBoxedFunctor(std::make_unique()); + */ + template + static KernelFunction makeFromBoxedFunctor( + std::unique_ptr kernelFunctor); + + /** + * Create a KernelFunction from an unboxed function. + * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction + * because knowing the function pointer as a template argument (i.e. at + * compile time) allows the compiler to inline the function into its + * unboxing wrapper and yields better performance when calling the function. + * + * Example: + * + * > Tensor unboxed_func(Tensor a, Tensor b) {...} + * > KernelFunction func = + * KernelFunction::makeFromUnboxedFunction(); + */ + template + static KernelFunction makeFromUnboxedFunction(FuncPtr); + + /** + * Create a KernelFunction from an unboxed function. + * KernelFunction::makeFromUnboxedFunction is usually a better choice than + * this if you know the function pointer at compile time, see doc comment + * there for an explanation. + * + * Example: + * + * > Tensor unboxed_func(Tensor a, Tensor b) {...} + * > KernelFunction func = + * KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func); + */ + template + static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func); + + static KernelFunction makeFallthrough(); + static KernelFunction makeAmbiguousAutogradOther(); + static KernelFunction makeNamedNotSupported(); + + /** + * Create a KernelFunction from an unboxed lambda. 
+ * + * Example: + * + * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( + * > [] (Tensor a, bool b) -> Tensor {...}); + */ + template + static std::enable_if_t< + guts::is_stateless_lambda>::value, + KernelFunction> + makeFromUnboxedLambda(Lambda&& lambda); + template + static std::enable_if_t< + !guts::is_stateless_lambda>::value, + KernelFunction> + makeFromUnboxedLambda(Lambda&& lambda); + + std::string dumpState() const; + // For testing internal invariants only + bool _equalsBoxedAndUnboxed(const KernelFunction&) const; + + private: + explicit KernelFunction( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func); + explicit KernelFunction( + BoxedKernel boxed_fn, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func); + + BoxedKernel boxed_kernel_func_; + void* unboxed_kernel_func_; + void* sym_unboxed_kernel_func_; +}; + +} // namespace c10 + +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..9d8e0369044a7a76f44b5ae0edc88664878c7bf0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h @@ -0,0 +1,320 @@ +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +namespace detail { +template +std::enable_if_t< + !std::is_array_v && !std::is_array_v && + std::is_base_of_v, + std::unique_ptr> +make_unique_base(Args&&... args) { + return std::unique_ptr(new Child(std::forward(args)...)); +} +} // namespace detail + +inline KernelFunction::KernelFunction() + : boxed_kernel_func_(), + unboxed_kernel_func_(nullptr), + sym_unboxed_kernel_func_(nullptr) {} + +inline KernelFunction::KernelFunction( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func = nullptr) + : boxed_kernel_func_(std::move(functor), boxed_kernel_func), + unboxed_kernel_func_(unboxed_kernel_func), + sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} + +inline KernelFunction::KernelFunction( + BoxedKernel boxed_fn, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func = nullptr) + : boxed_kernel_func_(std::move(boxed_fn)), + unboxed_kernel_func_(unboxed_kernel_func), + sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} + +inline bool KernelFunction::isValidUnboxed() const { + return unboxed_kernel_func_ != nullptr; +} + +inline bool KernelFunction::isValidSymUnboxed() const { + return sym_unboxed_kernel_func_ != nullptr; +} + +inline bool KernelFunction::isValid() const { + return boxed_kernel_func_.isValid(); +} + +inline bool KernelFunction::isFallthrough() const { + return boxed_kernel_func_.isFallthrough(); +} + +inline void KernelFunction::callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const { + boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack); +} + +template +inline Return callUnboxedKernelFunction( + void* unboxed_kernel_func, + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + Args&&... 
args) { + using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...); + ActualSignature* func = + reinterpret_cast(unboxed_kernel_func); + return (*func)(functor, dispatchKeySet, std::forward(args)...); +} + +// This template requires you to explicitly specify the argument you want to +// forward; it doesn't work if you try to deduce it +// NB: keep this in sync with cloneWithRealTypes in function_schema.cpp + +template +inline typename remove_symint::type unpackSymInt(T x) { + return x; +} + +template <> +inline typename remove_symint::type unpackSymInt(c10::SymInt x) { + return x.guard_int(__FILE__, __LINE__); +} + +template <> +inline typename remove_symint::type unpackSymInt( + c10::SymIntArrayRef x) { + return C10_AS_INTARRAYREF_SLOW(x); +} + +template <> +inline typename remove_symint>::type unpackSymInt( + std::optional x) { + return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__)) + : std::nullopt; +} + +template <> +inline typename remove_symint::type unpackSymInt( + at::OptionalSymIntArrayRef x) { + return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) + : std::nullopt; +} + +template +C10_ALWAYS_INLINE Return KernelFunction::call( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) const { + // note: Args above is intentionally not Args&&. We don't want perfect + // forwarding, which would require Args to be deduced, but instead we + // want callers to explicitly specify the Args. + + if constexpr (std::disjunction_v...>) { + if (sym_unboxed_kernel_func_ != nullptr) { + auto* functor = boxed_kernel_func_.getFunctor(); + return callUnboxedKernelFunction( + sym_unboxed_kernel_func_, + functor, + dispatchKeySet, + std::forward(args)...); + } + + if (unboxed_kernel_func_ != nullptr) { + auto* functor = boxed_kernel_func_.getFunctor(); + return callUnboxedKernelFunction< + Return, + typename remove_symint::type...>( + unboxed_kernel_func_, + functor, + dispatchKeySet, + unpackSymInt(args)...); + } + } else { + if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { + auto* functor = boxed_kernel_func_.getFunctor(); + return callUnboxedKernelFunction( + unboxed_kernel_func_, + functor, + dispatchKeySet, + std::forward(args)...); + } + } + + return impl::BoxedKernelWrapper::call( + boxed_kernel_func_, + opHandle, + dispatchKeySet, + std::forward(args)...); +} + +inline KernelFunction KernelFunction::makeFromBoxedKernel( + BoxedKernel boxed_fn) { + return KernelFunction( + std::move(boxed_fn), nullptr); // no unboxed function pointer +} + +template +inline KernelFunction KernelFunction::makeFromBoxedFunction() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunction()); +} + +template +inline KernelFunction KernelFunction::makeFromBoxedFunction() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunction()); +} + +inline KernelFunction KernelFunction::makeFallthrough() { + return KernelFunction::makeFromBoxedKernel(BoxedKernel::makeFallthrough()); +} + +inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeAmbiguousAutogradOther()); +} + +inline KernelFunction KernelFunction::makeNamedNotSupported() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeNamedNotSupported()); +} + +template +inline KernelFunction KernelFunction::makeFromUnboxedFunctor( + std::unique_ptr kernelFunctor) { +#ifndef NDEBUG + // This assertion is costly for build time so it's debug-gated. 
+ static_assert( + guts::is_functor::value, + "Tried to call KernelFunction::makeFromUnboxedFunctor but the argument is not a functor."); +#endif + static_assert( + std::is_base_of_v, + "Tried to call KernelFunction::makeFromUnboxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + + auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed::call; + void* void_unboxed_fn = reinterpret_cast(unboxed_fn); + bool is_symint = fn_has_symint::value; + return KernelFunction( + std::move(kernelFunctor), + &impl::make_boxed_from_unboxed_functor:: + call, + is_symint ? nullptr : void_unboxed_fn, + is_symint ? void_unboxed_fn : nullptr); +} + +template +inline KernelFunction KernelFunction::makeFromBoxedFunctor( + std::unique_ptr kernelFunctor) { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunctor(std::move(kernelFunctor))); +} + +template +inline KernelFunction KernelFunction::makeFromUnboxedFunction( + FuncPtr func_ptr) { + static_assert( + is_compile_time_function_pointer::value, + "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN."); + static_assert( + !std::is_same_v, + "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); +#if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__) + TORCH_INTERNAL_ASSERT( + FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); +#else + static_assert( + FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); +#endif + +#if !defined(C10_MOBILE) + (void)func_ptr; // Suppress unused variable warning + return makeFromUnboxedFunctor< + AllowLegacyTypes, + typename impl::WrapFunctionIntoFunctor::type>( + detail::make_unique_base< + OperatorKernel, + typename impl::WrapFunctionIntoFunctor::type>()); +#else + // On mobile, we rather want to optimize for binary size than for performance, + // so let's not inline the kernel into the wrapper but use + // makeFromUnboxedRuntimeFunction instead. + return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr()); +#endif +} + +template +inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction( + FuncType* func) { + static_assert( + guts::is_function_type::value, + "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type."); + static_assert( + !std::is_same_v, + "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. 
Please use KernelFunction::makeFromBoxedFunction instead."); + TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr"); + + return makeFromUnboxedFunctor< + AllowLegacyTypes, + impl::WrapFunctionIntoRuntimeFunctor>>( + detail::make_unique_base< + OperatorKernel, + impl::WrapFunctionIntoRuntimeFunctor>>(func)); +} + +template +inline std::enable_if_t< + guts::is_stateless_lambda>::value, + KernelFunction> +KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { + static_assert( + guts::is_functor>::value, + "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); + +#if !defined(C10_MOBILE) + return makeFromUnboxedFunctor< + AllowLegacyTypes, + impl::WrapFunctionIntoRuntimeFunctor>>( + detail::make_unique_base< + OperatorKernel, + impl::WrapFunctionIntoRuntimeFunctor>>( + std::forward(lambda))); +#else + // On mobile, we rather want to optimize for binary size than for performance, + // so let's not inline the kernel into the wrapper but use + // makeFromUnboxedRuntimeFunction instead. + using FuncType = + typename guts::infer_function_traits_t>::func_type; + return makeFromUnboxedRuntimeFunction(lambda); +#endif +} + +template +inline std::enable_if_t< + !guts::is_stateless_lambda>::value, + KernelFunction> +KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { + static_assert( + guts::is_functor>::value, + "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); + + return makeFromUnboxedFunctor< + AllowLegacyTypes, + impl::WrapFunctionIntoRuntimeFunctor>>( + detail::make_unique_base< + OperatorKernel, + impl::WrapFunctionIntoRuntimeFunctor>>( + std::forward(lambda))); +} + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..6083e45043362b461f1eb20e74602e32a6f46b00 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h @@ -0,0 +1,27 @@ +#pragma once +#include + +namespace c10 { + +/** + * Inherit from OperatorKernel to implement a c10 kernel. + * + * Example: + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * + * The kernel class is allowed to have members but these are equivalent + * to global variables. The kernel implementation is responsible for + * preventing race conditions on them. + * + * See below for how to register this kernel with PyTorch. + */ +struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target { + ~OperatorKernel() override = default; +}; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h new file mode 100644 index 0000000000000000000000000000000000000000..f2d192bc844270ffdbca8486e48482350710e8ff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +namespace c10::impl { +namespace detail { +template +class WrapFunctionIntoFunctor_ {}; +template +class WrapFunctionIntoFunctor_< + FuncPtr, + ReturnType, + guts::typelist::typelist> + final : public c10::OperatorKernel { + public: + C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... 
args) { + return (*FuncPtr::func_ptr())(std::forward(args)...); + } +}; +} // namespace detail + +// WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel +// functor. Since it is a compile time function pointer, many compilers can +// inline it into the wrapper and you don't get any performance overhead for +// wrapping. +template +struct WrapFunctionIntoFunctor final { + static_assert( + c10::is_compile_time_function_pointer::value, + "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN."); + using type = detail::WrapFunctionIntoFunctor_< + FuncPtr, + typename guts::function_traits::return_type, + typename guts::function_traits< + typename FuncPtr::FuncType>::parameter_types>; +}; + +} // namespace c10::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h new file mode 100644 index 0000000000000000000000000000000000000000..4fc1c0904de391f4ae411eff04b5c8d604952cdf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace c10::impl { + +namespace detail { +template +class WrapFunctionIntoRuntimeFunctor_ {}; +template +class WrapFunctionIntoRuntimeFunctor_< + FuncType, + ReturnType, + guts::typelist::typelist> + final : public c10::OperatorKernel { + public: + template + explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func) + : kernel_func_(std::forward(kernel_func)) {} + + decltype(auto) operator()(Parameters... args) { + return kernel_func_(std::forward(args)...); + } + + private: + FuncType kernel_func_; +}; +} // namespace detail + +// WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that +// inherits from c10::OperatorKernel, so it can be used as a c10 kernel. +// This can, for example, be used for lambdas, functors or even function +// pointers. In the case of function pointers, since it is a runtime function +// pointer, there is an overhead for calling it whenever the kernel is invoked. +template +using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_< + FuncType, + typename guts::infer_function_traits_t::return_type, + typename guts::infer_function_traits_t::parameter_types>; + +} // namespace c10::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h new file mode 100644 index 0000000000000000000000000000000000000000..6e234c1b65e2a3d0d9c47e5eb81feece876f822e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h @@ -0,0 +1,410 @@ +#pragma once + +// This file contains boxing (not unboxing) logic, +// i.e. how to make a vector from a set of concrete arguments. 
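+// ---------------------------------------------------------------------------
+// Illustrative sketch only (not part of this header): OperatorKernel.h above
+// says "see below for how to register this kernel with PyTorch". One common
+// registration path, roughly per torch/library.h, looks like the following;
+// the namespace, operator name and kernel body here are made up, and the exact
+// plumbing can differ between PyTorch versions.
+//
+// > TORCH_LIBRARY(myops, m) {
+// >   m.def("my_kernel(Tensor a, Tensor b) -> Tensor");
+// > }
+// > TORCH_LIBRARY_IMPL(myops, CPU, m) {
+// >   m.impl("my_kernel", [](const at::Tensor& a, const at::Tensor& b) {
+// >     return a + b;  // stand-in kernel body
+// >   });
+// > }
+//
+// Roughly speaking, a lambda registered this way ends up wrapped by
+// WrapFunctionIntoRuntimeFunctor (above) and exposed as a KernelFunction.
+// ---------------------------------------------------------------------------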
+ +#include +#include +#include + +#include + +#include +#include + +namespace c10::impl { + +// +// utils +// + +// is_mutable_tensor_ref +template +struct is_mutable_tensor_ref : std::false_type {}; +template <> +struct is_mutable_tensor_ref : std::true_type {}; + +// is_tuple_of_mutable_tensor_refs +// +template +struct is_tuple_of_mutable_tensor_refs : std::false_type {}; + +template +struct is_tuple_of_mutable_tensor_refs< + T, + std::enable_if_t::value, void>> + : guts::typelist:: + all> {}; + +// has_ivalue_to tests the presence/absence of instance method +// IValue::to() +// +template +struct has_ivalue_to : std::false_type {}; + +template +struct ivalue_to_helper { + using type = decltype(std::declval().template to()); +}; +template +using ivalue_to_helper_t = typename ivalue_to_helper::type; + +template +struct has_ivalue_to>> : std::true_type {}; + +// +// boxing predicates +// + +// A boxable arg type is one that IValue has a constructor for. +template +using can_box = std::disjunction< + std::is_constructible>, + // TensorOptions are not directly constructible into IValue, + // but torch::jit::push knows how to handle them + std::is_same>>; + +template +using can_box_all = std::conjunction...>; + +// an unboxable result is one that can be extracted from an IValue +template +using can_unbox = std::conjunction< + std::disjunction< + has_ivalue_to, + // void returns are ok + std::is_same>, + std::negation>>; + +// +// boxArgs - utility for pushing unboxed args onto IValue stack +// +template +torch::jit::Stack boxArgs(Args... args) { + // TODO Reuse stack vector instead of allocating? + torch::jit::Stack stack; + stack.reserve(sizeof...(Args)); + torch::jit::push(stack, std::forward(args)...); + return stack; +} + +template +inline constexpr size_t boxed_size_one() { + static_assert( + !std::is_same_v, c10::TensorOptions>, + "need to patch this path to support TensorOptions passed by reference"); + return 1; +} + +// torch::jit::push pushes 4 values for a TensorOptions; this needs to +// be kept in sync. +template <> +inline constexpr size_t boxed_size_one() { + return 4; +} + +// NOTE: this could probably be simplified with C++17 fold expressions. +template +struct BoxedSize : std::integral_constant {}; +template +struct BoxedSize + : std::integral_constant< + size_t, + boxed_size_one() + BoxedSize::value> {}; + +template +static inline constexpr size_t boxed_size() { + return BoxedSize::value; +} + +template +C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValue*& dest, T& arg) { + new (dest++) IValue(arg); +} + +C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack( + IValue*& dest, + c10::TensorOptions options) { + new (dest++) IValue(c10::typeMetaToScalarType(options.dtype())); + new (dest++) IValue(options.layout()); + new (dest++) IValue(options.device()); + new (dest++) IValue(options.pinned_memory()); +} + +inline void boxArgsToStack(IValue*&) {} + +template +C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack( + IValue*& dest, + T& arg, + Args&... args) { + boxToStack(dest, arg); + boxArgsToStack(dest, args...); +} + +// +// PopResult is a helper class whose specializations handle popping single and +// multiple return values, respectively. 
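+// Illustrative sketch only (not from this header): a boxed call assembled by
+// hand with the helpers above would look roughly like
+//
+// > torch::jit::Stack stack = c10::impl::boxArgs(self, other, /*alpha=*/1.0);
+// > op.callBoxed(&stack);         // kernel pops its inputs, pushes its outputs
+// > at::Tensor result = torch::jit::pop(stack).toTensor();
+//
+// where `op` is assumed to be a c10::OperatorHandle and `self`/`other` are
+// Tensors; PopResult below is the generic form of that final pop.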
+// +template +struct PopResult final { + static Result call(Stack& stack) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return one value on the stack, ", + "but instead pushed ", + stack.size(), + " values."); + return std::move(stack[0]).to(); + } +}; + +template +struct PopResult> final { + using Result = std::tuple; + + static Result call(Stack& stack) { + // for tuple return types, boxed kernel has pushed multiple values onto the + // stack + constexpr int RetCount = sizeof...(Types); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == RetCount, + "Boxed kernel was expected to return ", + RetCount, + " values on the stack, ", + "but instead pushed ", + stack.size(), + " values."); + return pop_to_tuple_impl(stack, std::make_index_sequence()); + } + + private: + // note: this has been moved into its own helper only to avoid a parse error + // on `indices` otherwise. I'm sure there's an incantation that slips it past + // the parser but eh + template + static Result pop_to_tuple_impl( + Stack& stack, + std::index_sequence) { + return std::make_tuple((std::move(stack[indices]).template to())...); + } +}; + +// +// BoxedKernelWrapper +// +// For a given function type FT, BoxedKernelWrapper implements +// a `call` method that +// - takes a boxed kernel and unboxed arguments as specified by FT, +// - calls `boxArgs` to box the arguments +// - calls the boxed kernel +// - unboxes and returns the result +// +// The partial specializations below handle various cases: in +// particular, not all types appearing in op signatures are supported, +// and ops returning references have nonstandard wrapper implementations. +// + +// 1. The base specialization of BoxedKernelWrapper should never be +// instantiated. A "no call method defined on BoxedKernelWrapper" compile error +// means that an op signature has failed to trigger any of the partial +// specializations that follow this one. +// +template +struct BoxedKernelWrapper { + // The reason we're not just doing straight up static_assert(false, ...) here: + // Basically, the way to make sure a static_assert only fires if a template + // is actually instantiated (rather than every time the file is parsed) is to + // use template parameters in the expression, e.g. FuncType here. However, + // since `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the + // same effect. + static_assert( + sizeof(FuncType) != sizeof(FuncType), + "Function signature contains one or more unsupported parameter and/or return types. " + "Look for a nearby error like " + "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" " + "- (your function type) is the unsupported signature."); +}; + +// +// 2. Supported signatures, other than those involving non-const Tensor refs - +// i.e., "functional" ops. +// + +template +struct BoxedKernelWrapper< + Result(Args...), + std::enable_if_t< + can_box_all::value && can_unbox::value && + !is_tuple_of_mutable_tensor_refs::value, + void>> { + static Result call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) { + torch::jit::Stack stack = boxArgs(std::forward(args)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + + if constexpr (!std::is_same_v) { + // op has pushed one or more values onto the stack. + return PopResult::call(stack); + } else { + // op returns void, boxed kernel has pushed nothing onto stack. 
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.empty(), + "Boxed kernel was expected to return no values on the stack, ", + "but instead returned ", + stack.size(), + " values."); + } + } +}; + +// +// 3. in-place ops take a single non-const Tensor reference +// as their first argument, and return it. +// +// Note: all signatures matching this pattern are assumed to be for such ops. +// Because of this, the generated BoxedKernelWrapper specializations simply +// return the in-place argument. +// + +template +struct BoxedKernelWrapper< + at::Tensor&(at::Tensor&, OtherArgs...), + std::enable_if_t::value, void>> { + static at::Tensor& call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + at::Tensor& outArg, + OtherArgs... otherArgs) { + torch::jit::Stack stack = boxArgs( + outArg, std::forward(otherArgs)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return a single value on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + return outArg; + } +}; + +// +// 3.5. In-process migration to make in-place ops take and return +// const references instead. +template +struct BoxedKernelWrapper< + const at::Tensor&(const at::Tensor&, OtherArgs...), + std::enable_if_t::value, void>> { + static const at::Tensor& call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + const at::Tensor& outArg, + OtherArgs... otherArgs) { + torch::jit::Stack stack = boxArgs(outArg, otherArgs...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return a single value on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + return outArg; + } +}; + +// +// 4. out of place ops that take a single non-const Tensor reference as their +// final argument, and also return it. +// +// Note: all signatures matching this pattern are assumed to be for such ops. +// This assumption permits the generated BoxedKernelWrapper specializations to +// simply return out arguments. +// +template +struct BoxedKernelWrapper< + at::Tensor&(FirstArg, RestArgs...), + std::enable_if_t< + can_box_all::value + // this skips over in-place kernels with a non-const Tensor + // arg at the front, so those can unambiguously trigger the + // preceding specialization. + && !is_mutable_tensor_ref::value, + void>> { + static at::Tensor& call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + FirstArg firstArg, + RestArgs... restArgs) { + torch::jit::Stack stack = boxArgs( + std::forward(firstArg), std::forward(restArgs)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return a single value on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + // reusing restArgs after it has been forwarded here is ok because we know + // that the last element is of type `Tensor&`. + return std::get( + std::tuple{restArgs...}); + } +}; + +// +// 5. out of place ops that take multiple non-const Tensor references as their +// final arguments, and return them in a std::tuple. +// +// Note: all signatures matching this pattern are assumed to be for such ops. 
+// This assumption permits the generated BoxedKernelWrapper specializations to +// simply return the out arguments. +// +template +struct BoxedKernelWrapper< + Result(Args...), + std::enable_if_t< + can_box_all::value && + is_tuple_of_mutable_tensor_refs::value, + void>> { + static Result call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) { + using ArgTuple = std::tuple; + constexpr int RetCount = std::tuple_size(); + + torch::jit::Stack stack = boxArgs(std::forward(args)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == RetCount, + "Boxed kernel was expected to return ", + RetCount, + " values on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + // reusing args after it has been forwarded here is ok because we know + // that the last RetCount elements are of type `Tensor&`. + auto result = guts::tuple_take( + ArgTuple{std::forward(args)...}); + static_assert( + std::is_same_v, + "The parameter list of an op returning a tuple of Tensor references " + "must end with an equal number of Tensor reference parameters."); + return result; + } +}; + +} // namespace c10::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..66d7cad8d960f7cef7303ad404de99839d711703 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -0,0 +1,785 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace c10 { + +using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack + // to the c10 namespace. +class OperatorHandle; + +/* + * [Note: Argument forwarding in the dispatcher] + * + * The dispatcher uses a somewhat unusual way to forward arguments through + * several layers of wrapper functions. This can be confusing because an + * experienced C++ programmer would look at this and think "oh this is supposed + * to be forwarding a universal reference but the && is missing. This is a + * bug.". It is not a bug. The common way in C++ to forward arguments is to use + * universal references: + * + * > template void func(T&& arg) { func2(std::forward(arg)); } + * + * but that relies on inferring the correct reference type (i.e. value vs & vs + * &&) from the argument. In our case, we cannot rely on the argument as + * supplied by the caller, because that could infer a different reference type + * than was used in the kernel function. The correct reference type is dictated + * by the kernel signature and must be identical since we cast function pointers + * through void* pointers and mismatches would be UB. So we need a forwarding + * pattern that determines the reference type to use by looking at the + * explicitly supplied operator signature, not by looking at the argument we're + * calling it with. + * + * What does std::forward do, exactly? + * ------------------------------------ + * std::forward(t) is a way to cast t to the reference type supplied in T. + * Let's assume decay_t == U and T is either U or some reference of U. + * - std::forward(t) will return U&, no matter what kind of reference t is. + * - std::forward(t) will return U&&, no matter what kind of reference t + * is. 
+ * - std::forward(t) will return U&& (not U!), no matter what kind of + * reference t is. + * + * For universal references, that means that in the following function + * > template void func(T&& arg) { func2(std::forward(arg)); } + * + * - when called with arg being a rvalue reference or non-reference value, T + * gets inferred to be a non-reference U, and std::forward(t) will return + * U&&, correctly moving the argument. + * - when called with arg behind a lvalue reference, T gets inferred to be U& + * because that's the only way to match the signature (in C++, a type that is + * (T&)&& will collapse to T&). That means std::forward(t) will return U& and + * the value will not be moved but passed on as a lvalue reference. + * + * How do we use that? + * ------------------------------------ + * But std::forward can also be used outside of the common "universal + * forwarding" pattern to change reference types. So instead of following the + * common C++ pattern, we notice what std::forward() actually does, and that + * is it takes a value and changes its reference to the type of reference passed + * in as T. If we don't infer T but explicitly specify it, we can use this to + * forward based on an explicitly specified reference type instead of the + * inferred argument type. + * + * This is why many of the dispatcher functions look like + * > template func(T t) { func2(std::forward(t)); } + * instead of the common + * > template func(T&& t) { func2(std::forward(t)); } + * + * and are expected to be called by explicitly specifying the template + * parameters in a way that matches the expected operator signature at each call + * site. + */ + +namespace impl { +// supported_primitive_arg_types defines which primitive types we allow in +// kernel functions as arguments or returns. +// Additionally, we support lists, dicts and optionals containing these types. +using supported_primitive_arg_types = guts::typelist::typelist< + int64_t, + double, + bool, + std::string_view, + at::Tensor, + at::Scalar, + c10::QScheme, + c10::ScalarType, + c10::Device, + c10::DeviceIndex, + c10::Layout, + c10::MemoryFormat, + at::Dimname>; + +// We have an unboxed functor in hand that takes C++ arguments, and +// we're building a boxed functor wrapper for it that takes IValues. +// So "outside" is boxed and "inside" is unboxed. +// +// So a valid input type is one that our boxed functor wrapper can +// unbox from an IValue into a C++ value. +// +// Whereas a valid output type is one that our wrapper can recieve +// as a C++ value from the unboxed functor, and box into an IValue. + +// +// assert_is_valid_input_type +// checks that T can be unboxed from an IValue into a C++ value. +// + +template +struct assert_is_valid_input_type { + assert_is_valid_input_type() { + if constexpr (guts::typelist::contains:: + value) { + /* everything is ok, this is a primitive type */ + } else { + /* otherwise this must be an instance of a valid custom class, since it + can only have been created via IValue(x), which ensures this. 
*/ + } + } +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type {}; + +template +struct TypeCheckHelper; + +template +struct TypeCheckHelper {}; + +template +struct TypeCheckHelper + : TypeCheckHelper { + assert_is_valid_input_type check; +}; + +template +struct assert_is_valid_input_type< + std::tuple, + AllowDeprecatedTypes> + : TypeCheckHelper {}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported input type: Dict where Key is invalid. We only support int64_t, double, bool, and string."); +}; + +template +struct assert_is_valid_input_type< + std::unordered_map, + AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + AllowDeprecatedTypes, + "You tried to register a kernel with an unsupported input type: std::unordered_map. Please use Dict instead."); + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported input type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string."); +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: List. Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: ArrayRef. Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_input_type< + c10::OptionalArrayRef, + AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: OptionalArrayRef. Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: std::array. Please use std::array instead."); +}; + +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + // There is no reason to support float when we have double. Keep the API lean. + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported input type: const char*. Please use std::string_view instead."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t, T>>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported input type: vector. Please use List instead."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t< + std::is_integral_v && + !guts::typelist::contains::value>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported integral input type. 
Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead."); +}; + +// TODO: it probably would be good to tighten this up quite a bit more with +// an explicit list for everything + +// +// assert_is_valid_output_type +// + +template +struct assert_is_valid_output_type { + assert_is_valid_output_type() { + if constexpr (guts::typelist::contains:: + value) { + /* everything is ok, this is a primitive type */ + } else { + /* otherwise T is verified to be a registered custom class in the IValue + constructor, so no benefit in double-checking here */ + } + } +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type {}; + +template +struct assert_is_valid_output_type< + c10::OptionalArrayRef, + AllowDeprecatedTypes> + : assert_is_valid_output_type {}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported output type: Dict where Key is invalid. We only support int64_t, double, bool, and string."); + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: Dict. Please use Dict or Dict."); +}; + +template +struct assert_is_valid_output_type< + std::unordered_map, + AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + AllowDeprecatedTypes, + "You tried to register a kernel with an unsupported output type: std::unordered_map. Please use Dict instead."); + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported output type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string."); + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: std::unordered_map. Please use Dict or Dict."); +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: List. Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: std::vector. Please use List, List or Tensor instead."); + // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel + // with an unsupported output type: std::vector. Please use List + // instead."); +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: std::array. Please use std::array instead."); +}; + +// The following specialisations of assert_is_valid_output_type are technically +// not necessary since we would hit the base case and show an error message +// there if they didn't exist, but we can show a better error message +// in some common error scenarios. 
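+// Illustrative example (not from this file, names made up): a kernel that
+// satisfies the type rules enforced above for a schema such as
+// > myops::scale(Tensor self, float alpha, int repeat) -> Tensor
+// would be declared as
+// > at::Tensor scale(const at::Tensor& self, double alpha, int64_t repeat);
+// i.e. schema `float` becomes C++ double and schema `int` becomes int64_t.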
+template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + // There is no reason to support float when we have double. Keep the API lean. + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string."); +}; +template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported output type: const char*. Please use std::string_view instead."); +}; +template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t, T>>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported output type: vector. Please use List instead."); +}; +template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t< + std::is_integral_v && + !guts::typelist::contains::value>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string."); +}; + +// ivalue_to_arg + +template +struct decay_if_not_tensor final { + using type = std::decay_t; +}; + +template <> +struct decay_if_not_tensor final { + using type = at::Tensor&; +}; + +template <> +struct decay_if_not_tensor final { + using type = const at::Tensor&; +}; + +template +struct ivalue_to_arg final { + static decltype(auto) call(IValue& v) { + assert_is_valid_input_type(); + return std::move(v).to(); + } +}; + +// The following two specializations take advantage of specialized +// `toTensor()` overloads on IValue to avoid copying. +template +struct ivalue_to_arg final { + // We cannot use the default implementation if they asked for a + // `at::Tensor&` because it moves from the IValue, so it can't get + // an lvalue reference. + static at::Tensor& call(IValue& v) { + // Tensor& is valid, don't bother asserting + return v.toTensor(); + } +}; + +template +struct ivalue_to_arg final { + // We should not use the default implementation if they asked for + // a `const at::Tensor&` because it moves from the IValue and they + // didn't ask for that. + static const at::Tensor& call(IValue& v) { + // const Tensor& is valid, don't bother asserting + return v.toTensor(); + } +}; + +template +struct ivalue_to_arg final { + static List call(IValue& v) { + return v.toTensorList(); + } +}; + +template +struct ivalue_to_arg, AllowDeprecatedTypes> final { + // If an argument is ArrayRef, convert the IValue to a std::vector and + // pass that to the operator. std::vector is implicitly convertible to + // ArrayRef. 
+ static std::vector call(IValue& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + } +}; +template +struct ivalue_to_arg final { + static std::vector call(IValue& v) { + if (v.isIntList()) { + std::vector r; + auto src = v.toIntList(); + std::transform( + src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { + return c10::SymInt(i); + }); + return r; + } else { + return ivalue_to_arg, AllowDeprecatedTypes>:: + call(v); + } + } +}; +template +struct ivalue_to_arg, AllowDeprecatedTypes> + final { + static OptionalArray call(IValue& v) { + if (v.isIntList()) { + std::vector r; + auto src = v.toIntList(); + std::transform( + src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { + return c10::SymInt(i); + }); + return OptionalArray(std::move(r)); + } else { + return std::move(v).to>(); + } + } +}; +template +struct ivalue_to_arg>, AllowDeprecatedTypes> final { + // If an argument is std::optional>, convert the IValue to an + // std::optional> and pass that to the operator. + // OptionalArray is basically a std::optional> but + // implicitly convertible to std::optional>. + static OptionalArray call(IValue& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + } +}; + +template +struct ivalue_to_arg, AllowDeprecatedTypes> final { + // If an argument is OptionalArrayRef, convert the IValue to an + // std::optional> and pass that to the operator. + // OptionalArray is basically a std::optional> but + // implicitly convertible to OptionalArrayRef + static OptionalArray call(IValue& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + } +}; + +// return_to_ivalue +template +struct return_to_ivalue final {}; + +template +struct return_to_ivalue< + T, + AllowDeprecatedTypes, + std::enable_if_t>> + final { + static IValue call(T&& v) { + assert_is_valid_output_type(); + return c10::ivalue::from(std::move(v)); + } + static IValue copy(const T& v) { + assert_is_valid_output_type(); + return IValue(v); + } +}; + +// Special case to allow kernels to return `Tensor&`. +// TODO Delete this once kernels don't do that anymore +template +struct return_to_ivalue final { + static IValue call(at::Tensor& v) { + return c10::ivalue::from(v); + } + static IValue copy(at::Tensor& v) { + return IValue(v); + } +}; + +// wrap_kernel_functor_unboxed_ + +template +struct wrap_kernel_functor_unboxed_ final {}; + +// This specialization is for kernels with a first argument that is NOT of type +// DispatchKeySet This includes kernels with 0 arguments. +template +struct wrap_kernel_functor_unboxed_< + KernelFunctor, + ReturnType(ParameterTypes...)> + final { + static_assert( + std::is_same_v< + ReturnType, + typename guts::infer_function_traits_t::return_type>, + "Return type mismatch"); + static_assert( + std::is_same_v< + guts::typelist::typelist, + typename guts::infer_function_traits_t< + KernelFunctor>::parameter_types>, + "Parameter types mismatch"); + + // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes + // doesn't use && + static ReturnType call( + OperatorKernel* functor, + DispatchKeySet, + ParameterTypes... args) { + KernelFunctor* functor_ = static_cast(functor); + // Note [Plumbing Keys Through The Dispatcher 2] + // See Note [Plumbing Keys Through The Dispatcher] for the background. + // This functor explicitly takes in a dispatchKeySet and drops it on the + // floor- it does not forward it to the registered kernel. 
+ // + // This is due to the calling convention within the dispatcher, which + // expects all registered kernels to have a first argument of type + // DispatchKeySet. + // This is not the case for pretty much all manually written kernels, + // however- this functor serves to separate the calling convention of the + // dispatcher from the calling convention of manually written kernels. + return (*functor_)(std::forward(args)...); + } +}; + +// This specialization is for kernels with a first argument of type +// DispatchKeySet +template +struct wrap_kernel_functor_unboxed_< + KernelFunctor, + ReturnType(DispatchKeySet, ParameterTypes...)> + final { + static_assert( + std::is_same_v< + ReturnType, + typename guts::infer_function_traits_t::return_type>, + "Return type mismatch"); + static_assert( + std::is_same_v< + guts::typelist::typelist, + typename guts::infer_function_traits_t< + KernelFunctor>::parameter_types>, + "Parameter types mismatch"); + + // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes + // doesn't use && + static ReturnType call( + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + ParameterTypes... args) { + KernelFunctor* functor_ = static_cast(functor); + // We're explicitly taking in a dispatchKeySet and forwarding it to the + // registered kernel. See Note [Plumbing Keys Through The Dispatcher 2] for + // details. + return (*functor_)(dispatchKeySet, std::forward(args)...); + } +}; + +template +using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_< + KernelFunctor, + typename guts::infer_function_traits_t::func_type>; + +// call_functor_with_args_from_stack + +template < + class Functor, + bool AllowDeprecatedTypes, + size_t... ivalue_arg_indices, + typename... ArgTypes> +std::decay_t::return_type> +call_functor_with_args_from_stack_( + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + Stack* stack, + std::index_sequence, + guts::typelist::typelist*) { + (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would + // be unused and we have to silence the compiler warning. + + // We're explicitly filtering out DispatchKeySet from the argument list. + // Some kernels take a DispatchKeySet as their first argument in order to + // plumb keys through the dispatcher. We don't want to expose the + // DispatchKeySet type to jit, so we don't include this argument on the stack. + // See Note [Plumbing Keys Through The Dispatcher] for the background. + return wrap_kernel_functor_unboxed::call( + functor, + dispatchKeySet, + ivalue_to_arg< + typename decay_if_not_tensor::type, + AllowDeprecatedTypes>:: + call(torch::jit::peek( + *stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))...); +} + +template +std::decay_t::return_type> +call_functor_with_args_from_stack( + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + Stack* stack) { + // We're explicitly filtering out DispatchKeySet from the argument list. + // Some kernels take a DispatchKeySet as their first argument in order to + // plumb keys through the dispatcher. We don't want to expose the + // DispatchKeySet type to jit, so we don't include this argument on the stack. + // See Note [Plumbing Keys Through The Dispatcher] for the background. 
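+ // Illustrative (not from this header): the two wrap_kernel_functor_unboxed_
+ // specializations above mean both of the following kernel shapes are
+ // accepted; only the second one actually receives the key set (the names
+ // below are made up):
+ // > at::Tensor my_add(const at::Tensor& a, const at::Tensor& b);
+ // > at::Tensor my_add_with_keys(c10::DispatchKeySet ks,
+ // >                             const at::Tensor& a, const at::Tensor& b);
+ // In either case, boxed/jit callers never see the DispatchKeySet argument.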
+ using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func< + Functor>::parameter_types; + constexpr size_t num_ivalue_args = guts::typelist::size::value; + return call_functor_with_args_from_stack_( + functor, + dispatchKeySet, + stack, + std::make_index_sequence(), + static_cast(nullptr)); +} + +// push_outputs + +template +struct push_outputs final { + // Contrary to [Note: Argument forwarding in the dispatcher], we use + // OutputType&& here to avoid one extra call to the move constructor in this + // case. This is still not a universal reference though because OutputType is + // an explicitly specified class template parameter. + static void call(OutputType&& output, Stack* stack) { + torch::jit::push( + *stack, + return_to_ivalue::call( + std::forward(output))); + } + static void copy(const OutputType& output, Stack* stack) { + torch::jit::push( + *stack, + return_to_ivalue::copy(output)); + } +}; +template +struct push_outputs, AllowDeprecatedTypes> final { + static void call(std::tuple&& output, Stack* stack) { + call_( + std::move(output), + stack, + std::make_index_sequence()); + } + static void copy(const std::tuple& output, Stack* stack) { + copy_(output, stack, std::make_index_sequence()); + } + + private: + template + static void call_( + std::tuple&& output, + Stack* stack, + std::index_sequence) { + torch::jit::push( + *stack, + return_to_ivalue::call( + std::forward(std::get(output)))...); + } + template + static void copy_( + const std::tuple& output, + Stack* stack, + std::index_sequence) { + torch::jit::push( + *stack, + return_to_ivalue::copy( + std::get(output))...); + } +}; +template +struct push_outputs final { + static void call(int /*dummy*/, Stack* /*stack*/) {} + static void copy(int /*dummy*/, Stack* /*stack*/) {} +}; + +// make_boxed_from_unboxed_functor + +template +struct make_boxed_from_unboxed_functor final { + static_assert( + std::is_base_of_v, + "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + + static void call( + OperatorKernel* functor, + const OperatorHandle&, + DispatchKeySet dispatchKeySet, + Stack* stack) { + using ReturnType = + typename guts::infer_function_traits_t::return_type; + // We're explicitly filtering out DispatchKeySet from the argument list. + // Some kernels take a DispatchKeySet as their first argument in order to + // plumb keys through the dispatcher. We don't want to expose the + // DispatchKeySet type to jit, so we don't include this argument on the + // stack. See Note [Plumbing Keys Through The Dispatcher] for the + // background. + using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func< + KernelFunctor>::parameter_types; + constexpr bool has_outputs = !std::is_same_v; + constexpr size_t num_inputs = guts::typelist::size::value; + if constexpr (has_outputs) { + // Decay ReturnType to ReturnType_ so that if a reference gets returned, + // we actually store it by value and don't get a dangling reference. This + // is only required because some kernels still return `Tensor&`. 
[Note: + // VC++ and 'std': ambiguous symbol] + using ReturnType_ = ::std::decay_t; + ReturnType_ output = call_functor_with_args_from_stack< + KernelFunctor, + AllowDeprecatedTypes>(functor, dispatchKeySet, stack); + torch::jit::drop(*stack, num_inputs); + // See note [ VC++ and 'std': ambiguous symbol] + push_outputs::call( + ::std::move(output), stack); + } else { + call_functor_with_args_from_stack( + functor, dispatchKeySet, stack); + torch::jit::drop(*stack, num_inputs); + } + } +}; +} // namespace impl + +} // namespace c10 + +namespace torch { +using OperatorKernel = c10::OperatorKernel; +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..df8f3f969f5ac34c32c17be9522311558fb7c936 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h @@ -0,0 +1,140 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +template +inline std::vector makeStack(Inputs&&... inputs) { + return {std::forward(inputs)...}; +} + +inline at::Tensor dummyTensor( + c10::DispatchKeySet ks, + bool requires_grad = false) { + auto* allocator = c10::GetCPUAllocator(); + int64_t nelements = 1; + auto dtype = caffe2::TypeMeta::Make(); + int64_t size_bytes = nelements * dtype.itemsize(); + auto storage_impl = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + size_bytes, + allocator->allocate(size_bytes), + allocator, + /*resizable=*/true); + at::Tensor t = + at::detail::make_tensor(storage_impl, ks, dtype); + // TODO: We add this to simulate the ideal case where we only have Autograd + // backend keys + // on Tensor when it requires grad. But currently Autograd keys are + // added in TensorImpl constructor by default. + if (!requires_grad) { + t.unsafeGetTensorImpl()->remove_autograd_key(); + } + return t; +} + +inline at::Tensor dummyTensor( + c10::DispatchKey dispatch_key, + bool requires_grad = false) { + return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad); +} + +template +inline std::vector callOp( + const c10::OperatorHandle& op, + Args... args) { + auto stack = makeStack(std::forward(args)...); + op.callBoxed(&stack); + return stack; +} + +template +inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) { + return op.typed().call(std::forward(args)...); +} + +template +inline Result callOpUnboxedWithDispatchKey( + const c10::OperatorHandle& op, + c10::DispatchKey dispatchKey, + Args... args) { + return op.typed().callWithDispatchKey( + dispatchKey, std::forward(args)...); +} + +template +inline Result callOpUnboxedWithPrecomputedDispatchKeySet( + const c10::OperatorHandle& op, + c10::DispatchKeySet ks, + Args... 
args) { + return op.typed().redispatch( + ks, std::forward(args)...); +} + +inline void expectDoesntFindKernel( + const char* op_name, + c10::DispatchKey dispatch_key) { + auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); + EXPECT_ANY_THROW(callOp(*op, dummyTensor(dispatch_key), 5);); +} + +inline void expectDoesntFindOperator(const char* op_name) { + auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); + EXPECT_FALSE(op.has_value()); +} + +template +inline void expectThrows(Functor&& functor, const char* expectMessageContains) { + try { + std::forward(functor)(); + } catch (const Exception& e) { + EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains)); + return; + } + ADD_FAILURE() << "Expected to throw exception containing \"" + << expectMessageContains << "\" but didn't throw"; +} + +template +void expectListEquals(c10::ArrayRef expected, std::array actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], actual[i]); + } +} + +template +void expectListEquals(c10::ArrayRef expected, c10::ArrayRef actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], actual[i]); + } +} + +template +void expectListEquals(c10::ArrayRef expected, c10::List actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], actual.get(i)); + } +} + +template +void expectListEquals(c10::ArrayRef expected, std::vector actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], actual[i]); + } +} + +// NB: This is not really sound, but all of the type sets constructed here +// are singletons so it's fine +static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) { + return legacyExtractDispatchKey(t.key_set()); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h new file mode 100644 index 0000000000000000000000000000000000000000..be10dc4231ad61a1b7a4044fbe9e472deb80c150 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10::impl { + +// A CppSignature object holds RTTI information about a C++ function signature +// at runtime and can compare them or get a debug-printable name. +class TORCH_API CppSignature final { + public: + CppSignature(const CppSignature&) = default; + CppSignature(CppSignature&&) noexcept = default; + CppSignature& operator=(const CppSignature&) = default; + CppSignature& operator=(CppSignature&&) noexcept = default; + + template + static CppSignature make() { + // Normalize functors, lambdas, function pointers, etc. into the plain + // function type The first argument of the schema might be of type + // DispatchKeySet, in which case we remove it. We do this to guarantee that + // all CppSignature's for an operator will match, even if they're registered + // with different calling conventions. 
+ // See Note [Plumbing Keys Through The Dispatcher] + using decayed_function_type = + typename c10::remove_DispatchKeySet_arg_from_func< + std::decay_t>::func_type; + + return CppSignature(std::type_index(typeid(decayed_function_type))); + } + + std::string name() const { + return c10::demangle(signature_.name()); + } + + friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) { + if (lhs.signature_ == rhs.signature_) { + return true; + } + // Without RTLD_GLOBAL, the type_index comparison could yield false because + // they point to different instances of the RTTI data, but the types would + // still be the same. Let's check for that case too. + // Note that there still is a case where this might not work, i.e. when + // linking libraries of different compilers together, they might have + // different ways to serialize a type name. That, together with a missing + // RTLD_GLOBAL, would still fail this. + if (0 == strcmp(lhs.signature_.name(), rhs.signature_.name())) { + return true; + } + + return false; + } + + private: + explicit CppSignature(std::type_index signature) + : signature_(std::move(signature)) {} + std::type_index signature_; +}; + +inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) { + return !(lhs == rhs); +} + +} // namespace c10::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h new file mode 100644 index 0000000000000000000000000000000000000000..62f05c8f56d764fcb6d6c09f7b6ebb62511efa13 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h @@ -0,0 +1,279 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +namespace impl { + +// Take a DispatchKeySet for a Tensor and determine what the actual dispatch +// DispatchKey should be, taking into account TLS, and skipping backends which +// fall through. +// +// Unlike Tensor::key_set(), the value of this on a tensor can change depending +// on TLS. +// +// NB: If there is no valid dispatch key, this will return Undefined +inline DispatchKeySet computeDispatchKeySet( + DispatchKeySet ks, + // The key mask lets us eliminate (by zero entries) keys which should not + // be considered for dispatch. There are two cases when we use this: + // + // - If an operator's dispatch table contains a fallthrough entry, we + // should bypass it entirely when finding the key + // - If a user invokes with redispatch, the mask lets us + // zero out the key the user asked us to stop. + // + // These excluded backends are NOT tracked in the TLS, but must be applied + // AFTER TLS (since the backend may have been introduced for consideration + // by the included TLS), which is why you have to pass them in to this + // function (as opposed to just applying it to the input 'ks'). + DispatchKeySet key_mask) { + c10::impl::LocalDispatchKeySet local = + c10::impl::tls_local_dispatch_key_set(); + // TODO: It's a bit irritating that we have to do logical ORs here, it would + // be nice to only do one. Can always_included be folded into the TLS? Well, + // it's a bit troublesome, because fastpath TLS access requires the type of + // the TLS in question to be zero-initialized, so you don't actually win + // anything in that case. 
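+ // Worked example of the expression below (purely illustrative): with
+ //   ks        = {CPU, AutogradCPU}
+ //   included_ = {}            (nothing force-included via TLS)
+ //   excluded_ = {AutogradCPU} (e.g. autograd locally disabled via TLS)
+ //   key_mask  = FULL
+ // the computed set is {CPU}, so the CPU kernel is selected.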
+ return (((ks | local.included_) - local.excluded_) & key_mask); +} + +} // namespace impl + +namespace detail { +// A small gadget to extract the DispatchKeySet from types which are known +// to have it. Used to extract dispatch keys from unboxed calls. +struct MultiDispatchKeySet : at::IterArgs { + DispatchKeySet ts; + void operator()(const at::Tensor& x) { + ts = ts | x.key_set(); + } + void operator()(const std::optional& x) { + if (x.has_value()) { + ts = ts | x->key_set(); + } + } + void operator()(at::ArrayRef xs) { + for (const auto& x : xs) { + ts = ts | x.key_set(); + } + } + // Tensor?[] translates to this case. + void operator()(const c10::List>& xs) { + for (std::optional x : xs) { + if (x.has_value()) { + ts = ts | x.value().key_set(); + } + } + } + // Structured Tensor[] translates to this case + void operator()(const at::ITensorListRef& xs) { + for (const auto& x : xs) { + ts = ts | x.key_set(); + } + } + [[noreturn]] void operator()(at::ArrayRef>) { + // Just checking that the handling of Tensor?[] didn't change. + TORCH_INTERNAL_ASSERT(false); + } + void operator()(const at::Generator& gen) { + if (gen.defined()) { + ts = ts | gen.key_set(); + } + } + void operator()(const std::optional& gen) { + if (gen.has_value() && gen->defined()) { + ts = ts | gen->key_set(); + } + } + template + void operator()(const T&) { + // do nothing + } +}; + +// NB: take by const reference (Don't do universal forwarding here! You +// don't want to move into this function!) +template +DispatchKeySet multi_dispatch_key_set(const Args&... args) { + return MultiDispatchKeySet().apply(args...).ts; +} +} // namespace detail + +/** + * An instance of DispatchKeyExtractor knows how to get a dispatch key given + * a list of arguments for an operator call. + * + * The instance is specific for a certain operator as: + * - In boxed dispatch, different operators have different ways to extract + * the dispatch key (e.g. different numbers of arguments), and we precompute + * the stack locations we should look at; and + * - In all dispatch, some backends should be excluded from dispatch because + * they have been registered as fallthrough. The set of excluded backends + * varies from operator, as some operators may have overridden the + * fallthrough with custom behavior. 
+ * + * Note - this should maintain identical impl to the py dispatcher key + * extraction logic at pytorch/torch/dispatcher.py + */ +struct TORCH_API DispatchKeyExtractor final { + public: + static DispatchKeyExtractor make(const FunctionSchema& schema) { + return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema)); + } + + static DispatchKeyExtractor makeUninitialized() { + return DispatchKeyExtractor(c10::utils::bitset()); + } + + void registerSchema(const FunctionSchema& schema) { + TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset()); + dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema); + } + void deregisterSchema() { + dispatch_arg_indices_reverse_ = c10::utils::bitset(); + } + + DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const { + DispatchKeySet ks; + dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t + reverse_arg_index) { + const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1); + if (C10_LIKELY(ivalue.isTensor())) { + // NB: Take care not to introduce a refcount bump (there's + // no safe toTensorRef method, alas) + ks = ks | ivalue.unsafeToTensorImpl()->key_set(); + } else if (C10_UNLIKELY(ivalue.isTensorList())) { + // NB: use toListRef as it doesn't induce refcount bumps + // (toTensorListRef is not a thing) + for (const auto& nv : ivalue.toListRef()) { + auto* tensor = nv.unsafeToTensorImpl(); + ks = ks | tensor->key_set(); + } + } + // Tensor?[] translates to a c10::List so we need to peek inside + else if (C10_UNLIKELY(ivalue.isList())) { + for (const auto& elt : ivalue.toListRef()) { + if (elt.isTensor()) { + ks = ks | elt.toTensor().key_set(); + } + } + } + }); + // Keys that are fallthrough should be skipped + if (requiresBitsetPerBackend_) { + c10::impl::LocalDispatchKeySet tls = + c10::impl::tls_local_dispatch_key_set(); + auto backend_idx = + ((ks | tls.included_) - tls.excluded_).getBackendIndex(); + return impl::computeDispatchKeySet( + ks, nonFallthroughKeysPerBackend_[backend_idx]); + } else { + return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + } + } + + template + DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const { + auto ks = detail::multi_dispatch_key_set(args...); + // Keys that are fallthrough should be skipped + if (requiresBitsetPerBackend_) { + c10::impl::LocalDispatchKeySet tls = + c10::impl::tls_local_dispatch_key_set(); + auto backend_idx = + ((ks | tls.included_) - tls.excluded_).getBackendIndex(); + return impl::computeDispatchKeySet( + ks, nonFallthroughKeysPerBackend_[backend_idx]); + } else { + return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + } + } + + void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough); + + std::string dumpState() const; + void checkInvariants(const FunctionSchema& schema) const; + + private: + static bool isDispatchType(const Type& type) { + // Checking isSubtypeOf on a DynamicType heap-allocates a + // DynamicType version of the argument if it's not a DynamicType + // already, and this has measurable overhead during startup. 
+#ifdef C10_MOBILE + struct CachedTypes { + DynamicTypePtr listOfTensors; + DynamicTypePtr listOfOptionalTensors; + DynamicTypePtr optionalOfTensor; + }; + static const CachedTypes ct = { + DynamicType::create(*ListType::ofTensors()), + DynamicType::create(*ListType::ofOptionalTensors()), + DynamicType::create(*OptionalType::ofTensor())}; + return type.isSubtypeOf(c10::TypeFactory::get()) || + type.isSubtypeOf(ct.listOfTensors) || + type.isSubtypeOf(ct.listOfOptionalTensors) || + type.isSubtypeOf(ct.optionalOfTensor); +#else // C10_MOBILE + return type.isSubtypeOf(*TensorType::get()) || + type.isSubtypeOf(*ListType::ofTensors()) || + type.isSubtypeOf(*ListType::ofOptionalTensors()) || + type.isSubtypeOf(*OptionalType::ofTensor()); +#endif // C10_MOBILE + } + static c10::utils::bitset makeBitsetForDispatchArgs( + const FunctionSchema& schema) { + TORCH_CHECK( + schema.arguments().size() <= c10::utils::bitset::NUM_BITS(), + "The function schema has ", + schema.arguments().size(), + " arguments but this PyTorch build only supports ", + c10::utils::bitset::NUM_BITS()); + c10::utils::bitset dispatch_arg_indices_reverse; + for (const auto index : c10::irange(schema.arguments().size())) { + if (isDispatchType(*schema.arguments()[index].type())) { + dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index); + } + } + return dispatch_arg_indices_reverse; + } + + explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse) + : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse), + nonFallthroughKeys_(DispatchKeySet::FULL) { + for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { + nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL; + } + } + + // this is a bitset that has ones for each argument index which has to be + // considered for dispatch. This avoids having to iterate over the stack + // to find all the tensors. The bits are stored in reverse order, i.e. + // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from + // the top of the stack (i.e. the i-th last argument of the function) + // is relevant for dispatch. + // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just + // means you must do the fallthrough + c10::utils::bitset dispatch_arg_indices_reverse_; + + // Set of functionality keys for which the operator does NOT have fallthrough + // kernel. + DispatchKeySet nonFallthroughKeys_; + // Set of functionality keys for which the operator does NOT have fallthrough + // kernel, defined PER BACKEND. This is only needed if we know that the + // operator has a different set of fallthroughs defined for some backends. 
+ std::array nonFallthroughKeysPerBackend_; + // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast + // path), or if we need to fall back to the slower path and check + // nonFallthroughKeysPerBackend_ + bool requiresBitsetPerBackend_{false}; +}; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h new file mode 100644 index 0000000000000000000000000000000000000000..599a32972046284b3f90ce0ad6512d88b47d2f0a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h @@ -0,0 +1,937 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifndef NDEBUG +#include +#endif + +namespace c10 { + +TORCH_API bool show_dispatch_trace(); +TORCH_API void dispatch_trace_nesting_incr(); +TORCH_API void dispatch_trace_nesting_decr(); +TORCH_API int64_t dispatch_trace_nesting_value(); + +struct DispatchTraceNestingGuard { + DispatchTraceNestingGuard() { + dispatch_trace_nesting_incr(); + } + ~DispatchTraceNestingGuard() { + dispatch_trace_nesting_decr(); + } +}; + +class TORCH_API OperatorHandle; +template +class TypedOperatorHandle; + +/** + * Implement this interface and register your instance with the dispatcher + * to get notified when operators are registered or deregistered with + * the dispatcher. + * + * NB: registration events only occur when a 'def' occurs; we don't trigger + * on 'impl' or 'fallback' calls. + */ +class TORCH_API OpRegistrationListener { + public: + virtual ~OpRegistrationListener(); + + virtual void onOperatorRegistered(const OperatorHandle& op) = 0; + virtual void onOperatorDeregistered(const OperatorHandle& op) = 0; +}; + +namespace detail { +class RegistrationListenerList; +} +class SchemaRegistrationHandleRAII; + +/** + * Top-level dispatch interface for dispatching via the dynamic dispatcher. + * Most end users shouldn't use this directly; if you're trying to register + * ops look in op_registration + */ +class TORCH_API Dispatcher final { + private: + // For direct access to backend fallback information + friend class impl::OperatorEntry; + + struct OperatorDef final { + explicit OperatorDef(OperatorName&& op_name) : op(std::move(op_name)) {} + + impl::OperatorEntry op; + + // These refer to the number of outstanding RegistrationHandleRAII + // for this operator. def_count reflects only def() registrations + // (in the new world, this should only ever be 1, but old style + // registrations may register the schema multiple times, which + // will increase this count). def_and_impl_count reflects the number + // of combined def() and impl() registrations. When the last def() gets + // unregistered, we must immediately call the Deregistered listeners, but we + // must not actually delete the handle as there are other outstanding RAII + // destructors which will try to destruct and they had better still have a + // working operator handle in this case + size_t def_count = 0; + size_t def_and_impl_count = 0; + }; + friend class OperatorHandle; + template + friend class TypedOperatorHandle; + + struct Guard final { + Guard() : alive(true), mutex() {} + std::atomic alive; + std::mutex mutex; + }; + + public: + ~Dispatcher(); + + // Implementation note: this class abstracts over the fact that we have + // per-operator dispatch tables. 
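To make the reverse-index bitset used by DispatchKeyExtractor above more concrete, the following self-contained sketch mimics how bit i marks the i-th argument counted from the top of the stack; the toy schema, argument types, and stack contents are made up for illustration.

#include <bitset>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Toy schema with 4 arguments; pretend arguments 0 and 2 are Tensors.
  const std::size_t num_args = 4;
  const std::vector<std::size_t> tensor_args = {0, 2};

  std::bitset<64> dispatch_arg_indices_reverse;
  for (std::size_t index : tensor_args) {
    // Same transformation as makeBitsetForDispatchArgs: indices are stored
    // relative to the end of the argument list.
    dispatch_arg_indices_reverse.set(num_args - 1 - index);
  }

  // Pretend IValue stack: arguments pushed left to right, so the last one is on top.
  std::vector<std::string> stack = {"arg0:Tensor", "arg1:int", "arg2:Tensor", "arg3:bool"};

  for (std::size_t reverse_i = 0; reverse_i < num_args; ++reverse_i) {
    if (dispatch_arg_indices_reverse.test(reverse_i)) {
      // getDispatchKeySetBoxed would peek(stack, 0, reverse_i + 1) here.
      std::cout << "dispatch-relevant: " << stack[num_args - 1 - reverse_i] << "\n";
    }
  }
  return 0;
}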
This could be easily adjusted to have a + // single global hash table. + static Dispatcher& realSingleton(); + + C10_ALWAYS_INLINE static Dispatcher& singleton() { +#if !defined C10_MOBILE + // Implemented inline so that steady-state code needn't incur + // function-call overhead. We can't just inline `realSingleton` + // because the function-local static would get duplicated across + // all DSOs that include & use this header, leading to multiple + // singleton instances. + static Dispatcher& s = realSingleton(); + return s; +#else + // For C10_MOBILE, we should never inline a static function that + // has a static member, since the generated code calls + // __cxa_guard_acquire and __cxa_guard_release which help + // implement exactly once semantics for the initialization of the + // static Dispatcher& s above (for the non-mobile case). That + // additional code when duplicated across all operator stubs + // for every backend results in a lot of additional code + // being generated by the compiler. + return realSingleton(); +#endif + } + + // ------------------------------------------------------------------------ + // + // Accessing operators by schema + // + // ------------------------------------------------------------------------ + + /** + * Looks for an operator schema with the given name and overload name + * and returns it if it is registered WITH A SCHEMA. + * Returns nullopt otherwise. + */ + std::optional findSchema(const OperatorName& operator_name); + + /** + * Variant of findSchema that results in less code generated at the call site. + * It (1) takes const char* pointer rather than OperatorName (so we skip + * generating std::string constructor calls at the call site), and (2) + * it raises an exception if the operator is not found (so we skip + * generating exception raising code at the call site) + * + * Irritatingly, we still have to generate the handful of instructions + * for dealing with an exception being thrown during static initialization + * (e.g. __cxa_guard_abort). If we could annotate this method noexcept we + * could avoid this code too, but as the name of the function suggests, + * it does throw exceptions. + */ + OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name); + + // Like findSchema, but also returns OperatorHandle even if there is no schema + std::optional findOp(const OperatorName& operator_name); + + // Returns a list of all operator names present in the operatorLookupTable_ + const std::vector getAllOpNames(); + + // Returns a list of all operator names present in the operatorLookupTable_ + // for a given dispatch key + const std::vector getAllOpNamesForDispatchKey(DispatchKey k); + + // ------------------------------------------------------------------------ + // + // Invoking operators + // + // ------------------------------------------------------------------------ + + template + Return call(const TypedOperatorHandle& op, Args... args) + const; + + template + static Return callWithDispatchKeySlowPath( + const TypedOperatorHandle& op, + at::StepCallbacks& stepCallbacks, + DispatchKeySet dispatchKeySet, + const KernelFunction& kernel, + Args... args); + + // Like call, but intended for use in a redispatch in kernels that have + // explicitly performed the DispatchKey update calculatulation. This will take + // the DispatchKeySet completely as is and dispatch to the kernel of the + // corresponding highest priority key in the set. 
Note that this version of + // redispatch treats the inputted DispatchKeySet *as is*, and does NOT mask + // out the highest priority key. See Note [Plumbing Keys Through The + // Dispatcher] + template + Return redispatch( + const TypedOperatorHandle& op, + DispatchKeySet currentDispatchKeySet, + Args... args) const; + + // Invoke an operator via the boxed calling convention using an IValue stack + void callBoxed(const OperatorHandle& op, Stack* stack) const; + void callBoxedForDispatchKey( + const OperatorHandle& op, + DispatchKey dk, + Stack* stack) const; + + // TODO: This will only be useful if we write a backend fallback that plumbs + // dispatch keys (currently there are none) See Note [Plumbing Keys Through + // The Dispatcher] + void redispatchBoxed( + const OperatorHandle& op, + DispatchKeySet dispatchKeySet, + Stack* stack) const; + + bool hasBackendFallbackForDispatchKey(DispatchKey dk) { + auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk); + if (dispatch_ix < 0) + return false; + return backendFallbackKernels_[dispatch_ix].kernel.isValid(); + } + + // Used by torchdeploy/multipy for multiple interpreters racing. + void waitForDef(const FunctionSchema& schema); + void waitForImpl( + const OperatorName& op_name, + std::optional dispatch_key); + + // ------------------------------------------------------------------------ + // + // Performing registrations (NON user public; use op_registration) + // + // ------------------------------------------------------------------------ + + /** + * Register a new operator schema. + * + * If a schema with the same operator name and overload name already exists, + * this function will check that both schemas are exactly identical. + */ + RegistrationHandleRAII registerDef( + FunctionSchema schema, + std::string debug, + std::vector tags = {}); + + /** + * Register a kernel to the dispatch table for an operator. + * If dispatch_key is nullopt, then this registers a fallback kernel. + * + * @return A RAII object that manages the lifetime of the registration. + * Once that object is destructed, the kernel will be deregistered. + */ + // NB: steals the inferred function schema, as we may need to hold on to + // it for a bit until the real schema turns up + RegistrationHandleRAII registerImpl( + OperatorName op_name, + std::optional dispatch_key, + KernelFunction kernel, + std::optional cpp_signature, + std::unique_ptr inferred_function_schema, + std::string debug); + + /** + * Given an operator, tells the Dispatcher that we have implemented a fake + * impl for this op in the given Python module. Call this a "pystub". + */ + RegistrationHandleRAII registerPythonModule( + const OperatorName& op_name, + const char* pymodule, + const char* context); + + /** + * Given an operator, throws if we have a pystub. + */ + void throwIfHasPythonModule(OperatorName op_name); + + std::optional> getPyStub( + OperatorName op_name); + + /** + * Register a new operator by name. + */ + RegistrationHandleRAII registerName(OperatorName op_name); + + /** + * Register a fallback kernel for a backend. + * If an operator is called but there is no concrete kernel for the dispatch + * key of the given operator arguments, it will check if there is such a + * fallback kernel for the given dispatch key and, if yes, call that one. + */ + RegistrationHandleRAII registerFallback( + DispatchKey dispatch_key, + KernelFunction kernel, + std::string debug); + + /** + * Use to register whenever we had a TORCH_LIBRARY declaration in the frontend + * API. 
These invocations are only permitted once per program, so we raise + * an error if this is called again for the same namespace. + */ + RegistrationHandleRAII registerLibrary(std::string ns, std::string debug); + + // ------------------------------------------------------------------------ + // + // Listeners on registrations + // + // ------------------------------------------------------------------------ + + /** + * Add a listener that gets called whenever a new op is registered or an + * existing op is deregistered. Immediately after registering, this listener + * gets called for all previously registered ops, so it can be used to keep + * track of ops registered with this dispatcher. + */ + RegistrationHandleRAII addRegistrationListener( + std::unique_ptr listener); + + void checkInvariants() const; + + // + // ------------------------------------------------------------------------ + // + // Assertions + // + // ------------------------------------------------------------------------ + + /** + * For testing purposes. + * Returns a list of all operators that were created through calls to + * registerImpl(), without any corresponding calls to registerDef(). After + * static initialization is done this is almost certainly a bug, as the + * created OperatorHandle won't have any schema associated with it and users + * calling the op through the dispatcher won't be able to access it + * + * Note that we cannot enforce this invariant "as we go" during static + * initialization, due to undefined static initialization order- we have no + * guarantees over the order in which .def() and .impl() calls are registered + * in the dispatcher at static initialization time. So this function should + * only be called after static initialization. + */ + std::vector findDanglingImpls() const; + + /** + * Useful for inspecting global Dispatcher registration state. + * Returns the names of all operators with a kernel registered for the + * specified DispatchKey. If no DispatchKey is specified, it returns all + * registered operators. 
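The registerDef/registerImpl/registerLibrary entry points above are normally reached through the TORCH_LIBRARY macros rather than called directly. A minimal sketch of that path, assuming a libtorch build; the namespace "myops", the operator name, and the kernel are illustrative only.

#include <torch/library.h>
#include <ATen/ATen.h>

// Illustrative kernel for a made-up op "myops::mul_add".
at::Tensor mul_add(const at::Tensor& a, const at::Tensor& b, const at::Tensor& c) {
  return a * b + c;
}

// def() ends up in Dispatcher::registerDef (schema registration), and the
// TORCH_LIBRARY block itself goes through registerLibrary for the namespace.
TORCH_LIBRARY(myops, m) {
  m.def("mul_add(Tensor a, Tensor b, Tensor c) -> Tensor");
}

// impl() ends up in Dispatcher::registerImpl for the CPU dispatch key.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("mul_add", mul_add);
}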
+ */ + std::vector getRegistrationsForDispatchKey( + std::optional k) const; + + private: + Dispatcher(); + + static int64_t sequenceNumberForRunningRecordFunction( + DispatchKey dispatchKey, + DispatchKeySet dispatchKeySet); + static void runRecordFunction( + at::RecordFunction& guard, + at::RecordFunction::schema_ref_t schema_ref, + DispatchKey dispatchKey, + DispatchKeySet dispatchKeySet); + static void runRecordFunction( + at::RecordFunction& guard, + at::RecordFunction::schema_ref_t schema_ref, + DispatchKey dispatchKey, + DispatchKeySet dispatchKeySet, + c10::ArrayRef args); + +#ifdef FBCODE_CAFFE2 + static bool profilingOperatorEvents(); + static void fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref); + static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref); +#endif // FBCODE_CAFFE2 + + OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema); + OperatorHandle findOrRegisterName_(const OperatorName& op_name); + + void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name); + void deregisterImpl_( + const OperatorHandle& op, + const OperatorName& op_name, + std::optional dispatch_key, + impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle); + void deregisterName_(const OperatorHandle& op, const OperatorName& op_name); + void deregisterFallback_(DispatchKey dispatchKey); + void deregisterLibrary_(const std::string& ns); + void cleanup(const OperatorHandle& op, const OperatorName& op_name); + void checkSchemaCompatibility( + const OperatorHandle& op, + const FunctionSchema& schema, + const std::string& debug); + + std::list operators_; +#if !defined(C10_MOBILE) + LeftRight> + operatorLookupTable_; +#else + RWSafeLeftRightWrapper> + operatorLookupTable_; +#endif + // Map from namespace to debug string (saying, e.g., where the library was + // defined) + ska::flat_hash_map libraries_; + + std::array + backendFallbackKernels_; + + std::unique_ptr listeners_; + + // This condition variable gets notified whenever we add a new def/impl to the + // dispatch table. This is primarily used by multipy/torchdeploy, when + // we have multiple interpreters trying to register to the dispatch table. + // In this situation, whenever the non-primary interpreter would have tried + // to register to the dispatch table, instead it will check to see if the + // expected registration has already been made, and if it hasn't, wait on + // this condition variable to see if it was just racing with the primary + // interpreter. + // + // We expect it to be rare for there to be any waiters on this condition + // variable. This is mostly just to help give better diagnostics if + // something goes horribly wrong + std::condition_variable cond_var_; + + // Protect concurrent access to the dispatcher. We store this in a + // `shared_ptr` as we return callbacks that call back into dispatcher methods, + // and we need to be able to handle and guard against the event when the + // `Dispatcher` has been destroyed before the callbacks fire. + std::shared_ptr guard_; +}; + +/** + * This is a handle to an operator schema registered with the dispatcher. + * This handle can be used to register kernels with the dispatcher or + * to lookup a kernel for a certain set of arguments. 
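Given a registered operator, user code typically obtains an OperatorHandle from the dispatcher and invokes it through a TypedOperatorHandle. A sketch of that flow, assuming a libtorch build and the builtin aten::add.Tensor overload; the schema name and unboxed signature are spelled from memory and worth double-checking against the generated headers.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor add_via_dispatcher(const at::Tensor& a, const at::Tensor& b) {
  // findSchemaOrThrow throws if "aten::add" with overload "Tensor" is unknown.
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&, const at::Scalar&)>();
  // call() extracts the dispatch key set from the arguments and picks a kernel.
  return op.call(a, b, /*alpha=*/1);
}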
+ */ +class TORCH_API OperatorHandle { + template + friend struct std::hash; + + public: + OperatorHandle(OperatorHandle&&) noexcept = default; + OperatorHandle& operator=(OperatorHandle&&) noexcept = default; + OperatorHandle(const OperatorHandle&) = default; + OperatorHandle& operator=(const OperatorHandle&) = default; + // NOLINTNEXTLINE(performance-trivially-destructible) + ~OperatorHandle(); + + const OperatorName& operator_name() const { + return operatorDef_->op.operator_name(); + } + + bool hasSchema() const { + return operatorDef_->op.hasSchema(); + } + + const FunctionSchema& schema() const { + return operatorDef_->op.schema(); + } + + const std::string& debug() const { + return operatorDef_->op.debug(); + } + + std::string dumpState() const { + return operatorDef_->op.dumpState(); + } + + bool hasKernelForDispatchKey(DispatchKey k) const { + return operatorDef_->op.hasKernelForDispatchKey(k); + } + + bool isKernelFallthroughKernel(DispatchKey k) const { + return operatorDef_->op.kernelForDispatchKey(k).isFallthrough(); + } + + bool hasKernelForAnyDispatchKey(DispatchKeySet k) const { + return operatorDef_->op.hasKernelForAnyDispatchKey(k); + } + + bool hasComputedKernelForDispatchKey(DispatchKey k) const { + return operatorDef_->op.hasComputedKernelForDispatchKey(k); + } + + std::string dumpComputedTable() const { + return operatorDef_->op.dumpComputedTable(); + } + + void checkInvariants() const { + return operatorDef_->op.checkInvariants(); + } + + c10::ArrayRef getTags() const { + return operatorDef_->op.getTags(); + } + + void setReportErrorCallback_(std::unique_ptr callback) { + operatorDef_->op.setReportErrorCallback_(std::move(callback)); + } + + bool hasTag(const at::Tag& tag) const { + for (const auto& tag_ : getTags()) { + if (tag == tag_) { + return true; + } + } + return false; + } + + template + TypedOperatorHandle typed() const { + // NB: This assert is not 100% sound: you can retrieve a typed() operator + // handle prior to ANY C++ signature being registered on the operator + // and the check will say everything is OK (at which point you can then + // smuggle in a kernel that is typed incorrectly). For everything + // in core library this won't happen, because all the static registrations + // will be done by the time a typed() handle is acquired. 
+#if !defined C10_MOBILE + operatorDef_->op.assertSignatureIsCorrect(); + if (fn_has_symint::value) { + operatorDef_->op.assertSignatureIsCorrect< + typename fn_remove_symint::type>(); + } +#endif + return TypedOperatorHandle(operatorIterator_); + } + + void callBoxed(Stack* stack) const { + c10::Dispatcher::singleton().callBoxed(*this, stack); + } + + void callBoxed(Stack& stack) const { + callBoxed(&stack); + } + + void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const { + c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack); + } + + void redispatchBoxed(DispatchKeySet ks, Stack* stack) const { + c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack); + } + + template + PyObject* getPythonOp( + c10::impl::PyInterpreter* self_interpreter, + F slow_accessor) const { + return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor); + } + + bool operator==(const OperatorHandle& other) const { + return operatorDef_ == other.operatorDef_; + } + + bool operator!=(const OperatorHandle& other) const { + return operatorDef_ != other.operatorDef_; + } + + private: + explicit OperatorHandle( + std::list::iterator operatorIterator) + : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {} + friend class Dispatcher; + template + friend class TypedOperatorHandle; + + // Storing a direct pointer to the OperatorDef even though we + // already have the iterator saves an instruction in the critical + // dispatch path. The iterator is effectively a + // pointer-to-std::list-node, and (at least in libstdc++'s + // implementation) the element is at an offset 16 bytes from that, + // because the prev/next pointers come first in the list node + // struct. So, an add instruction would be necessary to convert from the + // iterator to an OperatorDef*. + Dispatcher::OperatorDef* operatorDef_; + + // We need to store this iterator in order to make + // Dispatcher::cleanup() fast -- it runs a lot on program + // termination (and presuambly library unloading). + std::list::iterator operatorIterator_; +}; + +/** + * This is a handle to an operator schema registered with the dispatcher. + * It holds the same information as an OperatorHandle, but it is templated + * on the operator arguments and allows calling the operator in an + * unboxed way. + */ +template +class TypedOperatorHandle final { + static_assert( + guts::false_t(), + "FuncType in OperatorHandle::typed was not a valid function type"); +}; +template +class TypedOperatorHandle final : public OperatorHandle { + public: + TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default; + TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default; + TypedOperatorHandle(const TypedOperatorHandle&) = default; + TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default; + + // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use + // && + C10_ALWAYS_INLINE Return call(Args... args) const { + return c10::Dispatcher::singleton().call( + *this, std::forward(args)...); + } + + // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use + // && + C10_ALWAYS_INLINE Return + redispatch(DispatchKeySet currentDispatchKeySet, Args... 
args) const { + return c10::Dispatcher::singleton().redispatch( + *this, currentDispatchKeySet, std::forward(args)...); + } + + private: + explicit TypedOperatorHandle( + std::list::iterator operatorIterator) + : OperatorHandle(operatorIterator) {} + friend class OperatorHandle; +}; + +namespace detail { +template +inline void unused_arg_(const Args&...) {} + +// CaptureKernelCall is intended to capture return values from Dispatcher +// unboxed kernel calls. A record function may request to get outputs from the +// kernel calls. For boxed kernels, it's straightforward, the returned values +// are in the stack object. The stack can be passed to record functions. For +// unboxed kernels, we need to handle different kinds of return values, cache +// them temporarily, then release the values for the actual function call +// return. +template +struct CaptureKernelCall { + template + CaptureKernelCall( + const F& kernel, + const TypedOperatorHandle& op, + const DispatchKeySet& dispatchKeySet, + Args&&... args) + // Calls the kernel and capture the result in output_. + : output_{kernel.template call( + op, + dispatchKeySet, + std::forward(args)...)} {} + // Wraps the return values in a Stack. + Stack getOutputs() { + Stack stack; + impl::push_outputs::copy(output_, &stack); + return stack; + } + // Since we are returning the output_, we don't expect the output_ to be used + // afterward. Copy elision and RVO do not apply to class data members. Using + // move semantic to avoid copies when possible. + ReturnType release() && { + return std::move(output_); + } + + private: + ReturnType output_; +}; + +// Handle the lvalue reference differently since it should not be moved. +template <> +inline at::Tensor& CaptureKernelCall::release() && { + return output_; +} + +// Handle case where the kernel returns void. +template <> +struct CaptureKernelCall { + template + CaptureKernelCall( + const F& kernel, + const TypedOperatorHandle& op, + const DispatchKeySet& dispatchKeySet, + Args&&... args) { + // Calling the kernel and no need to capture void. + kernel.template call( + op, dispatchKeySet, std::forward(args)...); + } + Stack getOutputs() { + return Stack(); + } + void release() && {} +}; + +TORCH_API void _print_dispatch_trace( + const std::string& label, + const std::string& op_name, + const DispatchKeySet& dispatchKeySet); + +} // namespace detail + +// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use && +template +inline Return Dispatcher::callWithDispatchKeySlowPath( + const TypedOperatorHandle& op, + at::StepCallbacks& stepCallbacks, + DispatchKeySet dispatchKeySet, + const KernelFunction& kernel, + Args... args) { + // If callbacks need inputs, we box the arguments and pass them to the guard. + // Note: For perf reasons we wouldn't want to prematurely box the arguments. + at::RecordFunction guard(std::move(stepCallbacks)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved()); + auto dispatchKey = dispatchKeySet.highestPriorityTypeId(); + auto& schema = op.schema(); + auto schema_ref = std::reference_wrapper(schema); + constexpr auto num_boxed_args = impl::boxed_size(); + if constexpr (num_boxed_args != 0) { + if (guard.needsInputs()) { + // If we used std::array here, we would + // have to spend time default constructing the IValues in + // boxedArgs. aligned_storage has no such requirement. 
+ // NOLINTNEXTLINE(*array*) + alignas(IValue) std::byte boxedArgs[num_boxed_args * sizeof(IValue)]; + // For debugging only; could be removed (but the compiler will do + // that for us and it's nice to have the extra assurance of + // correctness from our debug builds). + IValue* boxedArgsPtr = reinterpret_cast(boxedArgs); + impl::boxArgsToStack(boxedArgsPtr, args...); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + reinterpret_cast(boxedArgsPtr) == + boxedArgs + num_boxed_args * sizeof(IValue)); + // I don't *think* we need std::launder here, because IValue has + // no subclasses and no const or reference fields. + runRecordFunction( + guard, + schema_ref, + dispatchKey, + dispatchKeySet, + c10::ArrayRef( + reinterpret_cast(boxedArgs), num_boxed_args)); + boxedArgsPtr = reinterpret_cast(boxedArgs); + for (size_t ii = 0; ii < num_boxed_args; ++ii) { + (boxedArgsPtr + ii)->~IValue(); + } + } else { + runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet); + } + } else { + runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet); + } + + if (C10_UNLIKELY(guard.needsOutputs())) { + // Calls the kernel and capture the output temporarily to pass to + // RecordFunction. + detail::CaptureKernelCall captureKernelCall( + kernel, op, dispatchKeySet, std::forward(args)...); + guard.setOutputs(captureKernelCall.getOutputs()); + // Releases the captured output to return to caller. + return std::move(captureKernelCall).release(); + } + + // keeping the guard alive while executing the kernel + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); +} + +// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use && +template +C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call( + const TypedOperatorHandle& op, + Args... args) const { + auto dispatchKeySet = + op.operatorDef_->op.dispatchKeyExtractor() + .template getDispatchKeySetUnboxed(args...); +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[call]", toString(op.operator_name()), dispatchKeySet); + } +#endif + const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet); +#ifndef PYTORCH_DISABLE_PER_OP_PROFILING + auto step_callbacks = + at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY( + step_callbacks.has_value() && op.operatorDef_->op.isObserved())) { + return callWithDispatchKeySlowPath( + op, + *step_callbacks, + dispatchKeySet, + kernel, + std::forward(args)...); + } +#endif // PYTORCH_DISABLE_PER_OP_PROFILING + +#ifdef FBCODE_CAFFE2 + if (profilingOperatorEvents()) { + struct FireOpRAII { + FireOpRAII(at::RecordFunction::schema_ref_t schema_ref) + : schema_ref_(schema_ref) { + fireOpStartUSDT(schema_ref); + } + ~FireOpRAII() { + fireOpEndUSDT(schema_ref_); + } + at::RecordFunction::schema_ref_t schema_ref_; + } event(op.schema()); + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); + } else { + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); + } +#else + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); +#endif // FBCODE_CAFFE2 +} + +// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use && +template +inline Return Dispatcher::redispatch( + const TypedOperatorHandle& op, + DispatchKeySet currentDispatchKeySet, + Args... 
args) const { + // do not use RecordFunction on redispatch +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[redispatch]", toString(op.operator_name()), currentDispatchKeySet); + } +#endif + const KernelFunction& kernel = + op.operatorDef_->op.lookup(currentDispatchKeySet); + return kernel.template call( + op, currentDispatchKeySet, std::forward(args)...); +} + +inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) + const { + // note: this doesn't need the mutex because write operations on the list keep + // iterators intact. + const auto& entry = op.operatorDef_->op; + auto dispatchKeySet = + entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[callBoxed]", toString(op.operator_name()), dispatchKeySet); + } +#endif + const auto& kernel = entry.lookup(dispatchKeySet); +#ifndef PYTORCH_DISABLE_PER_OP_PROFILING + auto step_callbacks = + at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) { + at::RecordFunction guard(std::move(*step_callbacks)); + auto dispatchKey = dispatchKeySet.highestPriorityTypeId(); + auto& schema = op.schema(); + auto schema_ref = std::reference_wrapper(schema); + guard.needsInputs() + ? runRecordFunction( + guard, + schema_ref, + dispatchKey, + dispatchKeySet, + c10::ArrayRef(stack->data(), stack->size())) + : runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet); + + // keeping the guard alive while executing the kernel + kernel.callBoxed(op, dispatchKeySet, stack); + + if (C10_UNLIKELY(guard.needsOutputs())) { + guard.setOutputs(*stack); + } + return; + } +#endif // PYTORCH_DISABLE_PER_OP_PROFILING + kernel.callBoxed(op, dispatchKeySet, stack); +} + +// NB: this doesn't count as a "true" dispatcher jump, so no instrumentation +inline void Dispatcher::callBoxedForDispatchKey( + const OperatorHandle& op, + DispatchKey dk, + Stack* stack) const { + // note: this doesn't need the mutex because write operations on the list keep + // iterators intact. + const auto& entry = op.operatorDef_->op; + // We still compute this as we're obligated to pass it on to the internal + // kernel, if it is a boxed fallback + auto dispatchKeySet = + entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); + const auto& kernel = ([&]() { + if (op.hasKernelForDispatchKey(dk)) { + return entry.kernelForDispatchKey(dk); + } else { + auto idx = getDispatchTableIndexForDispatchKey(dk); + TORCH_INTERNAL_ASSERT(idx >= 0); + return backendFallbackKernels_[idx].kernel; + } + })(); + kernel.callBoxed(op, dispatchKeySet, stack); +} + +inline void Dispatcher::redispatchBoxed( + const OperatorHandle& op, + DispatchKeySet dispatchKeySet, + Stack* stack) const { + // note: this doesn't need the mutex because write operations on the list keep + // iterators intact. 
+ const auto& entry = op.operatorDef_->op; +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[redispatchBoxed]", toString(op.operator_name()), dispatchKeySet); + } +#endif + const auto& kernel = entry.lookup(dispatchKeySet); + return kernel.callBoxed(op, dispatchKeySet, stack); +} + +} // namespace c10 + +namespace std { + +template <> +struct hash { + size_t operator()(const c10::OperatorHandle& op) const noexcept { + return std::hash{}(static_cast(op.operatorDef_)); + } +}; + +} // namespace std diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h new file mode 100644 index 0000000000000000000000000000000000000000..ef2efd55af04ee5c9d2bb01683a31d4688435ccd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +struct TORCH_API ObservedOperators { + ObservedOperators() = delete; + + static bool isObserved(const OperatorName& name); + + static std::unordered_set& getUnobservedOperatorList(); +}; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h new file mode 100644 index 0000000000000000000000000000000000000000..e3c39781ece4e74323c1e671f77230551e543ff3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h @@ -0,0 +1,335 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#ifdef C10_MOBILE +#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY +#endif + +namespace c10 { + +class Dispatcher; + +namespace impl { + +// This data structure represents a kernel that was registered to us from a +// user. Unlike KernelFunction, AnnotatedKernel contains some extra metadata +// about the kernel that isn't necessary for actual dispatching (this is why +// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for +// giving good error messages. +struct AnnotatedKernel final { + AnnotatedKernel( + KernelFunction k, + std::unique_ptr s, + std::string d) + : kernel(std::move(k)), + inferred_function_schema(std::move(s)), + debug(std::move(d)) {} + AnnotatedKernel() = default; + KernelFunction kernel; + std::unique_ptr inferred_function_schema; + // A little debug string to help us identify the kernel in question. + // Most importantly it records the TORCH_LIBRARY block that did the + // registration. + std::string debug; +}; + +// This data structure represents operator schema, with metadata specifying +// where the registration of this schema occurred +struct AnnotatedSchema final { + AnnotatedSchema(FunctionSchema s, std::string d) + : schema(std::move(s)), debug(std::move(d)) {} + FunctionSchema schema; + std::string debug; +}; + +// Internal data structure that records information about a specific operator. +// It's not part of the public API; typically, users will interact with +// OperatorHandle instead. 
+// +// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher +// lock (this is important because some methods in OperatorEntry access +// dispatcher state) +class TORCH_API OperatorEntry final { + public: + explicit OperatorEntry(OperatorName&& operator_name); + + OperatorEntry(const OperatorEntry&) = delete; + OperatorEntry(OperatorEntry&&) noexcept = delete; + OperatorEntry& operator=(const OperatorEntry&) = delete; + OperatorEntry& operator=(OperatorEntry&&) noexcept = delete; + + const FunctionSchema& schema() const { + TORCH_INTERNAL_ASSERT( + schema_.has_value(), + "Tried to access the schema for ", + name_, + " which doesn't have a schema registered yet"); + return schema_->schema; + } + const std::string& debug() const { + TORCH_INTERNAL_ASSERT(schema_.has_value()); + return schema_->debug; + } + bool hasSchema() const { + return schema_.has_value(); + } + + bool isObserved() const { + return is_observed_; + } + + // We may allocate an OperatorEntry for an operator even when we don't + // have a schema. When we receive the schema registration, we post + // facto register a schema. + // + // NB: registerSchema/deregisterSchema are not idempotent; if you + // attempt to register a schema when one is already present or vice + // versa that is an error. (Refcounting for the registrations is + // handled in the OperatorHandle in Dispatcher) + void registerSchema( + FunctionSchema&&, + std::string&& debug, + std::vector tags = {}); + void deregisterSchema(); + + const OperatorName& operator_name() const { + return name_; + } + +#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY + using AnnotatedKernelContainer = std::array; +#else + using AnnotatedKernelContainer = std::list; +#endif + using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator; + + // Why are kernels and fallback asymmetric? It has to do with ownership. + // Kernels and the computed dispatch tables for them are canonically + // owned by OperatorEntry, but backend fallbacks are specified once + // and apply for all operators, so they should be owned by Dispatcher. + // However, the registration of a backend fallback affects the + // state of the computed dispatch table, so when a backend fallback + // is updated, we need to update the operator tables too. Thus, + // registerKernel is the mechanism by which we give kernels to + // operator entry to own (and update dispatch table), but we only + // need a non-owning mechanism to update fallback. 
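The kernels owned by an OperatorEntry follow a "newest registration wins" policy for each dispatch key, with older registrations restored when the newest one is deregistered (see the kernels_ comment further below). A toy sketch of that list discipline, using strings in place of real kernels:

#include <cassert>
#include <list>
#include <string>

int main() {
  std::list<std::string> kernels_for_cpu;

  kernels_for_cpu.push_front("kernel from library A");
  kernels_for_cpu.push_front("kernel from library B"); // re-registration overrides A
  assert(kernels_for_cpu.front() == "kernel from library B"); // what dispatchTable_ would use

  kernels_for_cpu.pop_front(); // library B unloads / its RAII handle is destroyed
  assert(kernels_for_cpu.front() == "kernel from library A"); // A becomes active again
  return 0;
}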
+ + // Precondition: Dispatcher::mutex_ is held + // Postcondition: caller is responsible for disposing of the kernel + AnnotatedKernelContainerIterator registerKernel( + const Dispatcher& dispatcher, + std::optional dispatch_key, + KernelFunction kernel, + std::optional cpp_signature, + std::unique_ptr inferred_function_schema, + std::string debug); + + // Precondition: Dispatcher::mutex_ is held + void deregisterKernel_( + const Dispatcher& dispatcher, + std::optional dispatch_key, + AnnotatedKernelContainerIterator kernel); + + // Precondition: Dispatcher::mutex_ is held + void updateFallback(const Dispatcher& dispatcher, DispatchKey dispatch_key); + + // Precondition: Dispatcher::mutex_ is held + void updateSchemaAliasAnalysis(AliasAnalysisKind a) { + TORCH_INTERNAL_ASSERT(schema_.has_value()); + schema_->schema.setAliasAnalysis(a); + } + + std::string dumpComputedTable() const; + std::string dumpState() const; + void checkInvariants() const; + + const DispatchKeyExtractor& dispatchKeyExtractor() const { + return dispatchKeyExtractor_; + } + + // Asserts that the given FuncType is correct for calling this operator in an + // unboxed way. + template + inline void assertSignatureIsCorrect() { + assertSignatureIsCorrect( + CppSignature::make(), fn_has_symint::value); + } + + void assertSignatureIsCorrect( + const CppSignature& call_signature, + bool has_symint) const; + + [[noreturn]] void reportError(DispatchKey dispatchKey) const; + + const KernelFunction& lookup(DispatchKeySet ks) const { + const auto idx = ks.getDispatchTableIndexForDispatchKeySet(); + if (C10_UNLIKELY(idx == -1)) { + reportError(ks.highestPriorityTypeId()); + } + const auto& kernel = dispatchTable_[idx]; + // A valid kernel *always* has a boxed kernel and *may* have an + // unboxed kernel. However, we typically do unboxed calls in at:: + // APIs, where the kernel 1) will very likely be valid and 2) + // should have an unboxed kernel. Checking the unboxed kernel + // first will allow us to avoid touching the boxed kernel at all + // in the common case. + if (C10_UNLIKELY(!kernel.isValidUnboxed())) { + if (!kernel.isValid()) { + reportError(ks.highestPriorityTypeId()); + } + } + return kernel; + } + + std::string listAllDispatchKeys() const; + + // Returns true if kernel_ has entry for any key in ks. + // + // Invariant: There are no alias keys in the passed-in dispatch key set. + // Note [No Alias Keys in DispatchKeySet] + // Alias keys should be checked using `hasKernelForDispatchKey` + // Alias keys shouldn't go inside of a DispatchKeySet, since they can + // technically have a value > 63 (causing overflow). + bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const; + // Returns true if kernel_ has entry for a particular key. + bool hasKernelForDispatchKey(DispatchKey k) const; + // Retrieves the kernel entry at a particular key. Symmetric with + // hasKernelForDispatchKey. To get the AnnotatedKernel, see + // getKernelForDispatchKey (private) + const KernelFunction& kernelForDispatchKey(DispatchKey k) const; + // Returns true if the "computed table" has an entry for a particular key. 
+ bool hasComputedKernelForDispatchKey(DispatchKey k) const; + // Returns all the operator tags added at the time of registration + const std::vector& getTags() const; + void setReportErrorCallback_(std::unique_ptr callback); + + template + PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) + const { + return py_cache_.ptr_or(self_interpreter, slow_accessor); + } + + private: + OperatorName name_; + std::optional schema_; +#ifndef C10_MOBILE + std::vector tags_; +#endif + std::array dispatchTable_; + DispatchKeyExtractor dispatchKeyExtractor_; + // Pointer to the torch.ops.ns.op.overload object for speed + c10::PyHandleCache py_cache_; + + // kernels_ stores all registered kernels for the corresponding dispatch key + // and catchAllKernels_ stores the catch-all kernels. + // If an operator library gets loaded that overwrites an already existing + // kernel, both kernels will be in that list but only the newer one will be in + // dispatchTable. If any of the kernels go away (say the library gets + // unloaded), we remove the kernel from this list and update the + // dispatchTable if necessary. + // Kernels in the list are ordered by registration time descendingly, + // newer registrations are before older registrations. + // We do not combine dispatchTable and kernels into one hash map because + // kernels is a larger data structure and accessed quite infrequently + // while dispatchTable is accessed often and should be kept small to fit + // into CPU caches. + // Invariants: + // - dispatchTable[dispatch_key] == kernels_[dispatch_key].front() + // - dispatchTable[dispatch_key] does not exist if and only if + // kernels_[dispatch_key] does not exist + // - If kernels_[dispatch_key] exists, then it has elements. + // It is never an empty list. + // + // Why do we do that? + // ----- + // We mostly do this to enable Jupyter notebooks where a cell registering + // a kernel could be executed multiple times and the later execution + // should overwrite the earlier one. Note that this still fails when the + // function schema changed between the executions, but it works as long + // as the function schema didn't change. A better solution would be to + // unload the old extension library from the Jupyter cell when the cell is + // re-executed and then only allow one kernel here, i.e. error if a kernel + // is already registered, but that's a lot of effort to implement and + // currently not high-pri. + ska::flat_hash_map< + DispatchKey, +#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY + // On mobile, we needn't worry about Jupyter notebooks. + std::array +#else + std::list +#endif + > + kernels_; + + const AnnotatedKernel& missingKernel() const; + const AnnotatedKernel& ambiguousAutogradOtherKernel() const; + + // cpp_signature_ stores function signature if any of + // the kernels was created in a way that allowed us to know the function + // signature (i.e. by supplying an unboxed C++ kernel function). + // If this is set, it will be used to check that future kernel + // registrations match and it will be used in unboxed function calls + // to verify their arguments against the known function signature. 
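The cpp_signature_ check above relies on the CppSignature mechanism from earlier in this directory: a kernel's unboxed C++ function type is reduced to a std::type_index and compared on later registrations and typed calls. A simplified standalone sketch of that idea (a toy type, not the real class):

#include <cassert>
#include <cstring>
#include <type_traits>
#include <typeindex>
#include <typeinfo>

// Toy version of the CppSignature idea: identify an unboxed kernel signature by
// the type_index of the decayed function type, falling back to comparing mangled
// names in case RTTI objects were duplicated across shared libraries.
struct ToySignature {
  std::type_index ti;

  template <class FuncType>
  static ToySignature make() {
    return ToySignature{std::type_index(typeid(std::decay_t<FuncType>))};
  }

  friend bool operator==(const ToySignature& lhs, const ToySignature& rhs) {
    return lhs.ti == rhs.ti || std::strcmp(lhs.ti.name(), rhs.ti.name()) == 0;
  }
};

int add_kernel(int, int); // declaration only; never called here

int main() {
  // The kernel function and the expected signature decay to the same type.
  assert(ToySignature::make<decltype(add_kernel)>() == ToySignature::make<int(int, int)>());
  // A mismatching signature compares unequal.
  assert(!(ToySignature::make<int(int, int)>() == ToySignature::make<float(int)>()));
  return 0;
}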
+ struct CppSignatureWithDebug { + CppSignature signature; + std::string debug; + std::optional dispatch_key; + }; + std::optional cpp_signature_; + std::optional sym_cpp_signature_; + + // A Python custom error handler for OperatorEntry::reportError + std::unique_ptr report_error_callback_; + + // Whether this operator needs to be observed with RecordFunction + const bool is_observed_; + + [[noreturn]] void reportSignatureError( + const CppSignature& call_signature, + const CppSignatureWithDebug& saved_signature) const; + const KernelFunction& computeDispatchTableEntry( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key) const; + std::pair + computeDispatchTableEntryWithDebug( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key) const; + // This function re-establishes the invariant that dispatchTable + // contains the front element from the kernels list for a given runtime + // dispatch key. + void updateDispatchTableEntry_( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key); + // Like above, but also handles alias dispatch keys. + void updateDispatchTable_( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key); + // Like above, but for ALL entries in the dispatch table. + void updateDispatchTableFull_(const c10::Dispatcher& dispatcher); + // Retrieves a pointer to AnnotatedKernel at + // kernels_.at(dispatch_key).front(). + const AnnotatedKernel* getKernelForDispatchKey( + DispatchKey dispatch_key) const; +}; + +} // namespace impl +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..22661a720f92bccfd74803ae70be8c5c576287e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace c10 { + +enum class AliasAnalysisKind : uint8_t { + INTERNAL_SPECIAL_CASE, + CONSERVATIVE, // The most conservative alias analysis type, assumes + // side-effects. This is the default analysis. + FROM_SCHEMA, + PURE_FUNCTION +}; + +#if !defined(_MSC_VER) +constexpr // Our current MSVC version has a bug that doesn't allow this to be + // constexpr. +#endif + inline const char* + toString(AliasAnalysisKind aliasAnalysisKind) { + return (aliasAnalysisKind == AliasAnalysisKind::CONSERVATIVE) ? "CONSERVATIVE" + : (aliasAnalysisKind == AliasAnalysisKind::FROM_SCHEMA) ? "FROM_SCHEMA" + : (aliasAnalysisKind == AliasAnalysisKind::PURE_FUNCTION) + ? "PURE_FUNCTION" + : (aliasAnalysisKind == AliasAnalysisKind::INTERNAL_SPECIAL_CASE) + ? 
"INTERNAL_SPECIAL_CASE" + : "UNKNOWN"; +} + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h new file mode 100644 index 0000000000000000000000000000000000000000..1ef8a9224ba8e72b768826d23e228fe12aef65b8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace c10 { + +class RegistrationHandleRAII final { + public: + explicit RegistrationHandleRAII(std::function onDestruction) + : onDestruction_(std::move(onDestruction)) {} + + ~RegistrationHandleRAII() { + if (onDestruction_) { + onDestruction_(); + } + } + + RegistrationHandleRAII(const RegistrationHandleRAII&) = delete; + RegistrationHandleRAII& operator=(const RegistrationHandleRAII&) = delete; + + RegistrationHandleRAII(RegistrationHandleRAII&& rhs) noexcept + : onDestruction_(std::move(rhs.onDestruction_)) { + rhs.onDestruction_ = nullptr; + } + + RegistrationHandleRAII& operator=(RegistrationHandleRAII&& rhs) noexcept { + onDestruction_ = std::move(rhs.onDestruction_); + rhs.onDestruction_ = nullptr; + return *this; + } + + private: + std::function onDestruction_; +}; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/dynamic_type.h b/phivenv/Lib/site-packages/torch/include/ATen/core/dynamic_type.h new file mode 100644 index 0000000000000000000000000000000000000000..61f51821c08acaa30001fbc5cd8aa5b39f9f32a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/dynamic_type.h @@ -0,0 +1,246 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace c10 { + +using DynamicTypeBits = std::uint32_t; +#define DYNAMIC_TYPE_BIT(x) (1u << x) + +constexpr DynamicTypeBits kDynamicCovariantTypeBit = DYNAMIC_TYPE_BIT(31); +constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30); + +constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1); +constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3); +constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4); +constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5); +constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7); +constexpr DynamicTypeBits kDynamicTupleTypeBit = DYNAMIC_TYPE_BIT(8); +constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10); + +#define FORALL_DYNAMIC_TYPES(_) \ + _(Tensor, DYNAMIC_TYPE_BIT(0), 1) \ + _(None, kDynamicNoneTypeBit, 1) \ + _(Bool, DYNAMIC_TYPE_BIT(2), 1) \ + _(Int, kDynamicIntTypeBit, 1) \ + _(Float, kDynamicFloatTypeBit, 1) \ + _(Complex, kDynamicComplexTypeBit, 1) \ + _(Number, \ + (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \ + 1) \ + _(String, DYNAMIC_TYPE_BIT(6), 1) \ + _(List, kDynamicListTypeBit, 0) \ + _(Tuple, (kDynamicTupleTypeBit | kDynamicCovariantTypeBit), 0) \ + _(Dict, DYNAMIC_TYPE_BIT(9), 0) \ + _(Class, kDynamicClassTypeBit, 0) \ + _(Optional, \ + (DYNAMIC_TYPE_BIT(11) | kDynamicNoneTypeBit | kDynamicCovariantTypeBit), \ + 0) \ + _(AnyList, (kDynamicListTypeBit | kDynamicAnyTypeBit), 1) \ + _(AnyTuple, \ + (kDynamicTupleTypeBit | kDynamicCovariantTypeBit | kDynamicAnyTypeBit), \ + 1) \ + _(DeviceObj, DYNAMIC_TYPE_BIT(12), 1) \ + _(StreamObj, DYNAMIC_TYPE_BIT(13), 1) \ + _(Capsule, DYNAMIC_TYPE_BIT(14), 1) \ + _(Generator, DYNAMIC_TYPE_BIT(15), 1) \ + _(Storage, DYNAMIC_TYPE_BIT(16), 1) \ + _(Var, 
DYNAMIC_TYPE_BIT(17), 0) \ + _(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \ + _(QScheme, DYNAMIC_TYPE_BIT(18), 1) \ + _(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \ + _(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \ + _(RRef, DYNAMIC_TYPE_BIT(21), 0) \ + _(Future, DYNAMIC_TYPE_BIT(22), 0) \ + _(Await, DYNAMIC_TYPE_BIT(23), 0) \ + _(Any, 0xffffffff, 1) + +#define FORALL_DYNAMIC_TYPES_FAKE(_) \ + _(ScalarType, kDynamicIntTypeBit, 1) \ + _(Layout, kDynamicIntTypeBit, 1) \ + _(SymInt, kDynamicIntTypeBit, 1) \ + _(MemoryFormat, kDynamicIntTypeBit, 1) + +#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type; + FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE) + FORALL_DYNAMIC_TYPES_FAKE(FORWARD_DECL_TYPE) +#undef FORWARD_DECL_TYPE + +class DynamicType; +using DynamicTypePtr = std::shared_ptr; + +/** + * DynamicType is designed as a low dependency type system for TorchScript. The + * existing JIT types are used for both compilation and runtime, which makes + * sense for server contexts because we often compile and run the model in + * the same process, however this doesn't hold for mobile devices where we + * always compiles a model ahead of time, therefore there will be dependencies + * which are not needed, but built with mobile runtime causing binary size + * bloat, by design. Every basic type like Int, Bool or String will bring their + * vtable, typeinfo, constructor, destructor and even more data from their + * specializations for STL types to the binary causing a long tail bloat. + * + * The core problem is about the complexity to implement and maintain a single + * type system for both analysis and execution purposes. Although they should + * have the exactly same semantics, in practice implement a unified abstraction + * adds conceptual and representational overhead for both sides of the world. + * + * To address the issues, DynamicType implements a minimal subset of JIT types + * and uses a generic algorithm to test all subtyping relations. To achieve + * this, we assign each dynamic type a single integer tag to represent its + * semantics. More specifically, a dynamic type is defined as a set of "control + * bits" and "data bits", where control bits describe the special behavior when + * testing a type and data bits map to identity of each nominal type. We use bit + * operations to perform all the tests. + * + * For example, a "covariant bit" is a control bit used to describe if a type + * is covariant, right now the most used one is tuple type, and in addition to + * the control bit, tuple type's data bit is the 8th bit from the LSB. Control + * bits start from MSB and data bits start from LSB. + * + * If two types are equal, then they are subtype of each other, also if the bits + * from one type tag is subset of the other tag, it automatically becomes a + * subtype of the other. This simplifies the subtyping logic a lot, and over the + * long term it is possible to adopt this scheme on the server side as well. + * Special cases can be added but they generally should not take too much code + * size. + * + * DynamicType may or may not inherit from c10::Type because it's not the core + * requirement of DynamicType to interface with existing JIT types, but we might + * want to inherit from c10::Type to reduce the migration cost. + */ +class DynamicType : public SharedType { + using ClassTypePtr = std::shared_ptr; + + /** + * A implementation detail to support NamedTuple. 
+ */ + struct LabeledDynamicType { + std::optional label; + DynamicTypePtr ty; + explicit LabeledDynamicType(DynamicTypePtr t) : ty(std::move(t)) {} + + bool equals(const LabeledDynamicType& other) const; + bool isSubtypeOf(const LabeledDynamicType& other) const; + }; + + public: + // TODO Change Ptr to DynamicTypePtr when all migrations are done. + using Ptr = TypePtr; + using ElementType = DynamicType; + ~DynamicType() override; + + struct Arguments { + Arguments() = default; + Arguments(c10::ArrayRef); + Arguments(const std::vector&, c10::ArrayRef); + std::vector elems; + }; + + enum class Tag : DynamicTypeBits { +#define DYNAMIC_TYPE_ITEM(NAME, VAL, _) NAME = VAL, + FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_ITEM) + FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_ITEM) +#undef DYNAMIC_TYPE_ITEM + }; + + bool equals(const Type& rhs) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + std::string str() const override; + static const TypeKind Kind = TypeKind::DynamicType; + static TORCH_API DynamicTypePtr create(Type& ty); + + explicit DynamicType(Tag, Arguments); + explicit DynamicType(Tag, std::string_view, Arguments); + + DynamicType(DynamicType&& other) = delete; + DynamicType(const DynamicType&) = delete; + DynamicType& operator=(const DynamicType&) = delete; + DynamicType& operator=(DynamicType&&) = delete; + + TypePtr containedType(size_t) const override; + size_t containedTypeSize() const override; + Tag tag() const { + return tag_; + } + const std::optional& name() const { + return name_; + } + const Arguments& arguments() const { + return arguments_; + } + TORCH_API TypeKind dynamicKind() const; + + // Should be used only on the server side to restore static type information. +#ifndef C10_MOBILE + TORCH_API +#endif + TypePtr fallback() const; + + private: + bool symmetric() const override { + return false; + } + friend struct Type; + // NOTE: Here we are using SingletonOrSharedTypePtr to mean + // "original-type-because-it-was-actually-a-DynamicType or shared". 
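A minimal sketch (not part of the patch) of the "bits subset => subtype" rule described in the class comment above, using only the tag constants declared earlier in this header; the real equals/isSubtypeOfExt implementations also handle control bits such as covariance, so this only illustrates the core idea:

#include <ATen/core/dynamic_type.h>

// Number's tag is declared above as the union of the Int, Float and Complex
// bits, so Int's bits are a subset of Number's bits and Int is treated as a
// subtype of Number by a plain bitwise test.
constexpr c10::DynamicTypeBits kIntTag = c10::kDynamicIntTypeBit;
constexpr c10::DynamicTypeBits kNumberTag =
    c10::kDynamicIntTypeBit | c10::kDynamicFloatTypeBit | c10::kDynamicComplexTypeBit;
static_assert((kIntTag & kNumberTag) == kIntTag,
              "every bit of Int is contained in Number");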
+ static SingletonOrSharedTypePtr create(const Type& ty); + DynamicType(const Type& other); + bool equals(const DynamicType& other) const; + + template + bool compareArguments(const DynamicType& other, const F& f) const { + if (arguments_.elems.size() != other.arguments_.elems.size()) { + return false; + } + for (size_t i = 0; i < arguments_.elems.size(); i++) { + if (!f(arguments_.elems[i], other.arguments_.elems[i])) { + return false; + } + } + return true; + } + + Tag tag_; + std::optional name_; + union { + Arguments arguments_; + ClassTypePtr class_; + }; +}; + +template +struct DynamicTypeTrait { + C10_NOINLINE static auto tagValue() { + TORCH_CHECK(false); + return DynamicType::Tag::Any; + } +}; + +namespace detail { +C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag); +} + +#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \ + template <> \ + struct TORCH_API DynamicTypeTrait { \ + C10_ERASE static auto tagValue() { \ + return DynamicType::Tag::NAME; \ + } \ + static constexpr bool isBaseType = IS_BASE_TYPE; \ + template \ + static std::enable_if_t getBaseType() { \ + static auto type = detail::makeBaseType(tagValue()); \ + return type; \ + } \ + }; // namespace c10 +FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE) +FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_TAG_VALUE) +#undef DYNAMIC_TYPE_TAG_VALUE + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/enum_tag.h b/phivenv/Lib/site-packages/torch/include/ATen/core/enum_tag.h new file mode 100644 index 0000000000000000000000000000000000000000..eeded58a4ee32ae35a4c7c33590fc3bac8bcf2a3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/enum_tag.h @@ -0,0 +1,25 @@ +#pragma once + +// @generated by torchgen/gen.py from enum_tag.h + +namespace at { + // Enum of valid tags obtained from the entries in tags.yaml + enum class Tag { + core, + cudagraph_unsafe, + data_dependent_output, + dynamic_output_shape, + flexible_layout, + generated, + inplace_view, + maybe_aliasing_or_mutating, + needs_contiguous_strides, + needs_exact_strides, + needs_fixed_stride_order, + nondeterministic_bitwise, + nondeterministic_seeded, + pointwise, + pt2_compliant_tag, + view_copy + }; +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/enum_type.h b/phivenv/Lib/site-packages/torch/include/ATen/core/enum_type.h new file mode 100644 index 0000000000000000000000000000000000000000..c3b9fd924bfacb208d1373c5ae289be77d551e71 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/enum_type.h @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include + +namespace c10 { + +struct EnumType; +using EnumTypePtr = std::shared_ptr; +using EnumNameValue = std::pair; +struct TORCH_API EnumType : public NamedType { + friend struct Type; + static const TypeKind Kind = TypeKind::EnumType; + + static EnumTypePtr create( + const c10::QualifiedName& qualified_class_name, + TypePtr value, + std::vector enum_names_values, + std::weak_ptr<::torch::jit::CompilationUnit> cu) { + switch (value->kind()) { + case TypeKind::IntType: + case TypeKind::FloatType: + case TypeKind::StringType: + return EnumTypePtr(new EnumType( + qualified_class_name, + std::move(value), + std::move(enum_names_values), + std::move(cu))); + default: + TORCH_CHECK( + false, + "Cannot create Enum with value type '", + value->str(), + "', only int, float and string are supported"); + } + } + + std::string str() const override { + return "Enum<" + annotation_str() + ">"; + } + + std::string repr_str() const override { + return str(); + 
} + + const TypePtr& getValueType() const { + return value_type_; + } + + bool equals(const Type& rhs) const override { + if (auto* enum_rhs = rhs.castRaw()) { + return name().has_value() && name() == enum_rhs->name() && + *getValueType() == *(enum_rhs->getValueType()) && + this->compilation_unit() == enum_rhs->compilation_unit(); + } + return false; + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + std::shared_ptr compilation_unit() + const { + auto cu = cu_.lock(); + return cu; + } + + const QualifiedName& qualifiedClassName() const { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name().value(); + } + + at::ArrayRef containedTypes() const override { + return value_type_; + } + + const at::ArrayRef enumNamesValues() const { + return enum_names_values_; + } + + private: + EnumType( + c10::QualifiedName qualified_class_name, + TypePtr value_type, + std::vector enum_names_values, + std::weak_ptr cu) + : NamedType(TypeKind::EnumType, std::move(qualified_class_name)), + value_type_(std::move(value_type)), + enum_names_values_(std::move(enum_names_values)), + cu_(std::move(cu)) {} + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return qualifiedClassName().qualifiedName(); + } + + TypePtr value_type_; + std::vector enum_names_values_; + std::weak_ptr<::torch::jit::CompilationUnit> cu_; +}; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/function.h b/phivenv/Lib/site-packages/torch/include/ATen/core/function.h new file mode 100644 index 0000000000000000000000000000000000000000..9f7bbe8b6e8470049c6f9b7e94a9054b1ae07892 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/function.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { +struct FunctionSchema; +} + +namespace at { +TORCH_API void launch(std::function func); +} + +namespace torch::jit { + +struct Graph; +struct Code; + +namespace mobile { +struct Code; +} + +using Stack = std::vector; +using Kwargs = std::unordered_map; +struct RecursiveMethodCallError : public std::exception {}; +using TaskLauncher = std::function)>; + +TORCH_API void preoptimizeGraph( + std::shared_ptr& graph, + bool disable_autocast = false); + +// A Function is a pure Graph with no implicit `self` object bound. +// It contains schema information and the executor that manages the +// execution of the function. Method is a wrapper around an +// underlying Function that also provides a `self` object. 
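A minimal usage sketch for the stack-based calling convention of the struct declared just below (hypothetical: `fn` stands for some concrete Function obtained elsewhere, e.g. from a scripted module, and torch::ones is only used as a convenient input): positional inputs are pushed onto a Stack of IValues, kwargs are matched by name against the schema, and operator() returns the front of the stack after run().

#include <ATen/core/function.h>
#include <torch/torch.h>

at::IValue call_once(torch::jit::Function& fn) {
  torch::jit::Stack stack;
  stack.emplace_back(torch::ones({2, 2}));                  // positional input
  torch::jit::Kwargs kwargs{{"alpha", c10::IValue(2.0)}};   // keyword input, resolved via the schema
  return fn(std::move(stack), kwargs);                      // checkAndNormalizeInputs + run + stack.front()
}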
+struct TORCH_API Function { + Function() = default; + Function(const Function&) = default; + Function& operator=(const Function&) = default; + Function(Function&&) noexcept = default; + Function& operator=(Function&&) noexcept = default; + virtual std::string_view doc_string() const { + static constexpr std::string_view no_doc_string; + return no_doc_string; + } + + virtual bool isGraphFunction() const { + return false; + } + + virtual void run(Stack& stack) = 0; + + virtual c10::intrusive_ptr runAsync( + Stack& /*stack*/, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + [[maybe_unused]] TaskLauncher taskLauncher = at::launch) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); + return {}; + } + + at::IValue operator()(Stack stack, const Kwargs& kwargs = Kwargs()) { + getSchema().checkAndNormalizeInputs(stack, kwargs); + run(stack); + return stack.front(); + } + + virtual const c10::QualifiedName& qualname() const = 0; + + const std::string& name() const { + return qualname().name(); + } + + // if this isn't yet defined, run its method_creator function + virtual void ensure_defined() = 0; + + virtual const c10::FunctionSchema& getSchema() const = 0; + + virtual size_t num_inputs() const = 0; + + virtual Function& setSchema(c10::FunctionSchema schema) = 0; + + // call() defines how different interpreter implementations interacts with + // Function objects. Basically interpreters need to provide a callback to + // communicate to Functions what to do if provided a Code object. + // Alternatively we could design the signature to return an optional Code + // object, but that requires special handling the null case in interpreter + // and the fallback behavior is not well defined by interpreter but rather + // Function themselves, so a callback approach is more reasonable than + // returning values. + // If call() returns true, then callback completes successfully, otherwise + // call() returns false. + + // Overload for server interpreter, a bailout size is needed for graph + // executor. + virtual bool call( + Stack&, + std::optional, + c10::function_ref) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); + return false; + } + + // Overload for mobile interpreter. + virtual bool call(Stack&, c10::function_ref) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); + return false; + } + + virtual ~Function() = default; +}; +} // namespace torch::jit diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/function_schema.h b/phivenv/Lib/site-packages/torch/include/ATen/core/function_schema.h new file mode 100644 index 0000000000000000000000000000000000000000..5b9479f074deeb14f176ecb9e84f2152aa4b48a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/function_schema.h @@ -0,0 +1,690 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// schema as used in the compiler for resolving function calls and reporting +// errors. These objects should be constructed from C10 schema once those +// are available. 
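The Argument and FunctionSchema constructors declared below can be used directly; a minimal sketch (hypothetical operator name and argument list, not taken from the patch) of building and printing a schema roughly equivalent to "myns::foo(Tensor self, int n=1) -> Tensor":

#include <ATen/core/function_schema.h>
#include <iostream>

void schema_demo() {
  c10::FunctionSchema schema(
      "myns::foo", /*overload_name=*/"",
      /*arguments=*/{
          c10::Argument("self", c10::TensorType::get()),
          c10::Argument("n", c10::IntType::get(),
                        /*N=*/std::nullopt,
                        /*default_value=*/c10::IValue(1))},
      /*returns=*/{c10::Argument("", c10::TensorType::get())});
  std::cout << schema << '\n';   // printed by the operator<< defined later in this header
}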
+ +struct Argument; +struct FunctionSchema; + +using AliasTypeSet = std::vector; + +bool operator==(const Argument& lhs, const Argument& rhs); + +struct TORCH_API Argument { + Argument( + std::string name = "", + const TypePtr& type = nullptr, + std::optional N = std::nullopt, + std::optional default_value = std::nullopt, + bool kwarg_only = false, + std::optional alias_info = std::nullopt) + : Argument(std::move(name), type, type, N, std::move(default_value), kwarg_only, std::move(alias_info)) {} + + Argument( + std::string name, + TypePtr fake_type, + TypePtr real_type, + std::optional N = std::nullopt, + std::optional default_value = std::nullopt, + bool kwarg_only = false, + std::optional alias_info = std::nullopt) + : name_(std::move(name)), + type_(fake_type ? std::move(fake_type) : TensorType::get()), + real_type_(real_type ? std::move(real_type) : type_), + N_(N), + default_value_(std::move(default_value)), + alias_info_(alias_info ? std::make_unique(std::move(*alias_info)) : nullptr), + kwarg_only_(kwarg_only) { + // this is an softly-enforced invariant for out arguments. + bool is_alias = alias_info_ != nullptr && alias_info_->isWrite(); + is_out_ = kwarg_only_ && is_alias; + } + + Argument(Argument&& rhs) noexcept = default; + + Argument(const Argument& rhs) + : name_(rhs.name_), + type_(rhs.type_), + real_type_(rhs.real_type_), + N_(rhs.N_), + default_value_(rhs.default_value_), + alias_info_(rhs.alias_info_ ? std::make_unique(*rhs.alias_info_) : nullptr), + kwarg_only_(rhs.kwarg_only_), + is_out_(rhs.is_out_) {} + + Argument& operator=(Argument&& rhs) = default; + + Argument& operator=(const Argument& rhs) { + if (this != &rhs) { + name_ = rhs.name_; + type_ = rhs.type_; + real_type_ = rhs.real_type_; + N_ = rhs.N_; + default_value_ = rhs.default_value_; + alias_info_ = rhs.alias_info_ ? std::make_unique(*rhs.alias_info_) : nullptr; + kwarg_only_ = rhs.kwarg_only_; + is_out_ = rhs.is_out_; + } + return *this; + } + ~Argument() = default; + + const std::string& name() const { + return name_; + } + const TypePtr& type() const { + return type_; + } + // if type() is non-null, this is guaranteed to be non-null (if no real + // type was provided, this takes on type()'s value) + const TypePtr& real_type() const { + return real_type_; + } + const std::optional& N() const { + return N_; + } + const std::optional& default_value() const { + return default_value_; + } + bool kwarg_only() const { + return kwarg_only_; + } + + bool is_out() const { + return is_out_; + } + + [[nodiscard]] const AliasInfo* alias_info() const { + return alias_info_.get(); + } + + bool is_inferred_type() const { + bool is_inferred_type = false; + TORCH_INTERNAL_ASSERT(type_); + if (auto pt = type_->cast()) { + if (pt->isInferredType()) { + is_inferred_type = true; + } + } + return is_inferred_type; + } + + std::string formatTypeMismatchMsg(const std::string& actual_type) const { + std::string inferred_type_hint; + if (is_inferred_type()) { + inferred_type_hint = c10::str( + "Inferred '", + name(), + "' to be of type 'Tensor' ", + "because it was not annotated with an explicit type.\n"); + } + return c10::str( + "Expected a value of type '", + type()->repr_str(), + "' for argument '", + name(), + "' but instead found type '", + actual_type, + "'.\n", + inferred_type_hint); + } + + Argument cloneWithType(const TypePtr& new_type) const { + return Argument( + name_, + new_type, + N_, + default_value_, + kwarg_only_, + alias_info_ ? 
std::optional(*alias_info_) : std::nullopt); + } + + // this function checks whether this Argument is backward compatible with + // the old one. we consider the following cases are backward compatible: + // 1) two arguments are equal + // 2) this arg's type should be subtype of old + // 3) this arg must provide the same default value if old arg has one, + bool isBackwardCompatibleWith( + const Argument& old, + std::ostream* why_not=nullptr) const; + + // this function checks whether this Argument is forward compatible with + // the old one. we consider the following cases are forward compatible: + // 1) two arguments are equal + // 2) this arg's type should be subtype of old + // 3) this arg must provide the same default value if old arg has one, + bool isForwardCompatibleWith( + const Argument& old, + std::ostream* why_not = nullptr) const; + + private: + std::string name_; + TypePtr type_; + TypePtr real_type_; // this is ScalarType, not int, e.g. + // for list types, an optional statically known length for the list + // e.g. for int[3]: type = ListType::ofInts(), N = 3 + // If present, this will allow scalars to be broadcast to this length to + // become a list. + std::optional N_; + + std::optional default_value_; + // AliasInfo is huge, so let's only allocate memory for it if + // necessary (which it isn't during schema parsing on startup, to + // give a pertinent example). + std::unique_ptr alias_info_; + // is this only specifiable as a keyword argument? + bool kwarg_only_; + // marks if the argument is out variant of the schema + bool is_out_; +}; + +inline bool operator==(const Argument& lhs, const Argument& rhs) { + return lhs.name() == rhs.name() + && *lhs.type() == *rhs.type() + && lhs.N() == rhs.N() + && lhs.default_value() == rhs.default_value() + && lhs.kwarg_only() == rhs.kwarg_only() + && (lhs.alias_info() == rhs.alias_info() + || (lhs.alias_info() != nullptr && rhs.alias_info() != nullptr + && *lhs.alias_info() == *rhs.alias_info())); +} + +inline bool operator!=(const Argument& lhs, const Argument& rhs) { + return !(lhs == rhs); +} + +enum struct TORCH_API SchemaArgType { input, output }; + +/** + * struct SchemaArgument + * + * Structure used to represent arguments or returns for a schema. + */ +struct TORCH_API SchemaArgument { + SchemaArgType type; + size_t index; + SchemaArgument(SchemaArgType tpe, size_t idx) : type(tpe), index(idx) {} + bool operator==(const SchemaArgument& rhs) const { + return type == rhs.type && index == rhs.index; + } +}; + +bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs); + +struct TORCH_API FunctionSchema { + FunctionSchema( + std::string name, + std::string overload_name, + std::vector arguments, + std::vector returns, + bool is_vararg = false, + bool is_varret = false) + : name_({std::move(name), std::move(overload_name)}), + arguments_(std::move(arguments)), + returns_(std::move(returns)), + is_vararg_(is_vararg), + is_varret_(is_varret) { + checkSchema(); + } + + FunctionSchema( + Symbol name, + std::string overload_name, + std::vector arguments, + std::vector returns, + bool is_vararg = false, + bool is_varret = false) + : FunctionSchema( + name.toQualString(), + std::move(overload_name), + std::move(arguments), + std::move(returns), + is_vararg, + is_varret) { + checkSchema(); + } + + // Checks whether this schema is backward compatible with the old one. + // The following conditions must be true: + // [Function structure] The new schema's name, overload-name, varargs, and + // return arity are the same. 
+ // [Output Narrowing] The new schema's output type must be the same class + // or inherit from the old schema's output type. + // [Argument count] The new schema must have at least as many arguments as + // the old schema (considering the list of positional and kwargs). + // [Arg Compatibility] Every argument in the old schema has a corresponding + // argument in the new schema that: + // * is at the same position. + // * has the same name. + // * is either positional, or kwarg and the old argument was kwarg. + // * has the same type, or the old argument's type inherits from the + // new argument's type. + // [Default Values] Every new argument must have a default value. + // E.g. + // OK f_new(a, b, c=1) => f_old(a, b) + // NOK f_new(a, c=1, *, b) => f_old(a, *, b) + // OK f_new(a, b, *, c) => f_old(a, *, b, c) + // NOK f_new(a, *, b, c) -> f_old(a, b, *, c) + // NOK f_new(a, *, c, b) => f_old(a, *, b, c) + // OK f_new(a, *, b, c, d=1) => f_old(a, *, b, c) + bool isBackwardCompatibleWith( + const FunctionSchema& old, + std::ostream* why_not = nullptr) const; + + // Checks whether this schema is forward compatible with the old one. + // The following conditions must be true: + // [Function structure] The new schema's name, overload-name, varargs, and + // return arity are the same. + // [Output Narrowing] The new schema's output type must be the same class + // or inherit from the old schema's output type. + // [Arg Compatibility] Every argument in the old schema has a corresponding + // argument in the new schema that: + // * is at the same position. + // * has the same name. + // * is either positional, or kwarg and the old argument was kwarg. + // * has the same type, or the old argument's type inherits from the + // new argument's type. + // [Default Values] Every new argument must have a default value. + // Each default value type should NOT be a container type. + // [Positioning] All defaults arguments MUST go after either old + // default arguments or the end of positional arguments + // and right BEFORE all out arguments + bool isForwardCompatibleWith( + const FunctionSchema& old, + std::ostringstream& why_not) const; + + private: + OperatorName name_; + std::vector arguments_; + std::vector returns_; + // if true then this schema takes an arbitrary number of additional arguments + // after the argument specified in arguments + // currently this is used primarily to represent 'primitive' operators whose + // arguments are not checked by schema + bool is_vararg_; + bool is_varret_; + + // if no alias information is directly specified, what kind of "default" + // alias information should we infer? + // NB: due to alias analysis kind merging, this may be nullopt. Eventually + // this should always be set no matter what + std::optional alias_kind_; + + template + void checkArg(const IValue& value, const Argument& argument, std::optional pos) const; + + void checkSchema() const { + bool seen_default_arg = false; + for (const auto& arg : arguments()) { + if (arg.default_value()) { + seen_default_arg = true; + } else { + // we have historically serialized broadcasting lists wo/default values, + // so to not break BC allow lists here + if (arg.type()->kind() == ListType::Kind) { + continue; + } + TORCH_INTERNAL_ASSERT( + !seen_default_arg || arg.kwarg_only(), + "Non-default positional argument follows default argument. 
Parameter ", + arg.name(), + " in ", + *this); + } + } + } + + public: + + void dump() const; + + const OperatorName& operator_name() const { + return name_; + } + const std::string& name() const { + return name_.name; + } + const std::string& overload_name() const { + return name_.overload_name; + } + const std::vector& arguments() const { + return arguments_; + } + const std::vector& returns() const { + return returns_; + } + bool is_vararg() const { + return is_vararg_; + } + bool is_varret() const { + return is_varret_; + } + bool is_aliasing(const c10::SchemaArgument &argument) const { + TORCH_INTERNAL_ASSERT( + argument.index < getCorrectList(argument.type).size(), + "Invalid index for schema."); + const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); + return aliasInfo; + } + bool is_mutable() const { + return std::any_of( + arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) { + const AliasInfo* aliasInfo = arg.alias_info(); + return aliasInfo && aliasInfo->isWrite(); + }); + } + bool is_mutable(const c10::SchemaArgument &argument) const { + TORCH_INTERNAL_ASSERT( + argument.index < getCorrectList(argument.type).size(), + "Invalid index for schema."); + const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); + return aliasInfo && aliasInfo->isWrite(); + } + bool is_mutable(std::string_view name) const { + std::optional index = argumentIndexWithName(name); + TORCH_INTERNAL_ASSERT( + index.has_value(), "Schema has no argument named ", name); + + return is_mutable({c10::SchemaArgType::input, static_cast(*index)}); + } + + // Returns whether lhs and rhs may alias directly. + // This does not account for cases where lhs or rhs are a container that + // may contain elements that alias the other argument. + // FunctionSchema::may_contain_alias will include that functionality. + bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const; + + // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container + // that may contain elements that alias the other argument. + // bidirectional = false only returns whether lhs may contain an alias of rhs + // while bidirectional = true returns both directions. + bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const; + + // Returns whether the two AliasTypeSets contain any similarities + // ie: whether the two type sets can alias. + bool canAliasTypeSetsAlias(const std::optional &lhs, const std::optional &rhs) const; + + // Recursively Finds all contained types within the AliasTypeSet. + std::optional getAliasTypeSetContainedTypes(const std::optional &aliasTypeSet) const; + + // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp. + // Used to map types to a type such that all types that can alias will be mapped to the same type. + // For example, calling this method on 'Optional[List[int]]' is the same as calling this method + // on 'List[int]'. 
+ std::optional mapTypeToAliasTypeSet(const TypePtr& type) const; + + // Returns either arguments() or returns() depending on the SchemaArgType + // output => returns(), input => arguments() + const std::vector& getCorrectList(SchemaArgType type) const; + + std::optional argumentIndexWithName(std::string_view name) const { + for (const auto i : c10::irange(arguments().size())) { + if(name == arguments()[i].name()) + return i; + } + return std::nullopt; + } + FunctionSchema cloneWithName(std::string name, std::string overload_name) const { + return FunctionSchema( + std::move(name), + std::move(overload_name), + arguments(), + returns(), + is_vararg(), + is_varret() + ); + } + FunctionSchema cloneWithArguments(std::vector new_arguments) const { + return FunctionSchema( + name(), + overload_name(), + std::move(new_arguments), + returns(), + is_vararg(), + is_varret()); + } + FunctionSchema cloneWithReturns(std::vector new_returns) const { + return FunctionSchema( + name(), + overload_name(), + arguments(), + std::move(new_returns), + is_vararg(), + is_varret()); + } + + std::string formatTypeMismatchMsg( + const Argument& expected, + const std::string& actual_type, + std::optional position = std::nullopt, + std::optional value = std::nullopt) const; + + FunctionSchema cloneWithRemappedTypes( + const std::function type_map) const; + + FunctionSchema cloneWithRealTypes(bool with_symint=true) const; + + // Check that inputs have the correct types and appends any missing default + // values. + template + void checkAndNormalizeInputs( + std::vector& inputs, + const std::unordered_map& kwargs = + std::unordered_map{}) const; + + std::string findErrorInKwargs(const std::vector& kwargs) const; + + bool hasAnyAliasInfo() const { + for (const auto& arg : arguments_) { + if (arg.alias_info() != nullptr) { + return true; + } + } + for (const auto& ret : returns_) { + if (ret.alias_info() != nullptr) { + return true; + } + } + return false; + } + + + // TODO remove the mutation here + bool isDefaultAliasAnalysisKind() const { + return !alias_kind_; + } + AliasAnalysisKind aliasAnalysis() const { + return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE); + } + void setAliasAnalysis(AliasAnalysisKind v) { + alias_kind_ = v; + } + + std::optional getNamespace() const { + return name_.getNamespace(); + } + + // Returns true if we successfully set the namespace (as there + // was none set, and false otherwise) + bool setNamespaceIfNotSet(const char* ns) { + return name_.setNamespaceIfNotSet(ns); + } + + // can a function with this schema be substituted for a function of rhs's + // schema and have the program typecheck? + // as_method - if true, treat this schema as a method and ignore + // the first argument, which will be the object in both cases + bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const; +}; + +inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) { + return lhs.name() == rhs.name() + && lhs.overload_name() == rhs.overload_name() + && lhs.arguments() == rhs.arguments() + && lhs.returns() == rhs.returns() + && lhs.is_vararg() == rhs.is_vararg() + && lhs.is_varret() == rhs.is_varret(); +} + +inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) { + return !(lhs == rhs); +} + +// print out Argument, which is compatible with FunctionSchema parser +// full format: Type(alias)? name=default_value +inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { + + // for adjusting the ? position. 
+ // in schema, we have Tensor?(a!) input, and t(a!)?. + // however, t?(a!) doesn't work with schema parser. + // so we always use Type(alias)? format + // real_type versus fake_type: in order to be compatible with FunctionSchema + // parser, printing an argument with either MemoryFormat or Layout type should + // give us the original schema string, hence printing out real_type. + auto type = arg.real_type(); + bool is_opt = type->kind() == OptionalType::Kind; + auto unopt_type = is_opt ? type->castRaw()->getElementType() : type; + + if (unopt_type->kind() == ListType::Kind) { + // sized lists get size N from arg, not type + auto list = unopt_type->cast(); + out << list->getElementType()->str(); + if (arg.alias_info() && !arg.alias_info()->containedTypes().empty()){ + out << arg.alias_info()->containedTypes()[0]; + } + std::string N; + if (arg.N()) { + N = std::to_string(*arg.N()); + } + out << "[" << N << "]"; + } else { + out << unopt_type->str(); + } + + // print alias info if it has beforeSets. + if (arg.alias_info() && !arg.alias_info()->beforeSets().empty()) { + out << *arg.alias_info(); + } + + if (is_opt) { + out << "?"; + } + + if (!arg.name().empty()) { + out << " " << arg.name(); + } + + if (arg.default_value()) { + out << "="; + if ((type->kind() == c10::TypeKind::StringType || + unopt_type->kind() == c10::TypeKind::StringType) && + arg.default_value().value().isString()) { + printQuotedString(out, arg.default_value().value().toStringRef()); + } else if (type->kind() == TypeKind::ListType && type->castRaw()->getElementType()->kind() == c10::TypeKind::IntType) { + // We want to faithfully replicate JIT schema. + // in native_functions.yaml defaults for int arrays with a single value always look like + // int[2] stride=1 + // instead of + // int[2] stride=[1, 1] + auto default_val = arg.default_value().value().toIntList(); + if (default_val.size() > 1) { + auto all_defaults_the_same = true; + for (const auto i : c10::irange(1, default_val.size())) { + if (default_val[0] != default_val[i]) all_defaults_the_same = false; + } + if (all_defaults_the_same) { + out << default_val[0]; + } else { + out << arg.default_value().value(); + } + } else { + out << arg.default_value().value(); + } + } else { + out << arg.default_value().value(); + } + } + + return out; +} + +TORCH_API std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema); + +inline std::string toString(const FunctionSchema& schema) { + std::ostringstream str; + str << schema; + return str.str(); +} + +} // namespace c10 + +namespace std { +template<> + struct hash { + size_t operator()(const c10::SchemaArgument& arg) const + { + return c10::hash_combine(std::hash()(arg.index), std::hash()(static_cast(arg.type))); + } + }; +template<> + struct hash { + size_t operator()(const c10::Argument& arg) const + { + auto hash = std::hash{}(arg.name()); + auto type_hash = std::hash{}(arg.type()); + auto kwarg_only_hash = std::hash{}(arg.kwarg_only()); + hash = c10::hash_combine(hash, type_hash); + hash = c10::hash_combine(hash, kwarg_only_hash); + // hashing optional fields if they exist + if (arg.default_value().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto default_value_hash = c10::hash{}(*arg.default_value()); + hash = c10::hash_combine(hash, default_value_hash); + } + if (arg.N().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto N_hash = std::hash{}(*arg.N()); + hash = c10::hash_combine(hash, N_hash); + } + if (arg.alias_info()) { + auto alias_info_hash = 
std::hash{}(*arg.alias_info()); + hash = c10::hash_combine(hash, alias_info_hash); + } + return hash; + } + }; +template<> + struct hash { + size_t operator()(const c10::FunctionSchema& schema) const + { + auto hash = std::hash{}(schema.operator_name()); + auto args_hash = c10::hash>{}(schema.arguments()); + auto returns_hash = c10::hash>{}(schema.returns()); + auto is_vararg_hash = std::hash{}(schema.is_vararg()); + auto is_varret_hash = std::hash{}(schema.is_varret()); + hash = c10::hash_combine(hash, args_hash); + hash = c10::hash_combine(hash, returns_hash); + hash = c10::hash_combine(hash, is_vararg_hash); + hash = c10::hash_combine(hash, is_varret_hash); + return hash; + } + }; +} // namespace std + + +#include // IWYU pragma: keep diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/function_schema_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/function_schema_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..cbd96401ce1d57bc228e70237efdcebb80978276 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/function_schema_inl.h @@ -0,0 +1,78 @@ +#pragma once +#include +#include + +namespace c10 { + +template +inline void FunctionSchema::checkArg( + const IValue& value, + const Argument& argument, + std::optional pos) const { + if (value.isTensor() && argument.type() == TensorType::get()) { + // Fast-path for the common case + return; + } + if (value.isGenericDict() && value.toGenericDict().empty()) { + return; + } + if (!value.type()->isSubtypeOf(*argument.type())) { + TORCH_CHECK( + false, + formatTypeMismatchMsg( + argument, value.type()->repr_str(), pos)); + } +} + +template +inline void FunctionSchema::checkAndNormalizeInputs( + std::vector& inputs, + const std::unordered_map& kwargs) const { + // Do we have more inputs than the schema accepts? + TORCH_CHECK( + inputs.size() <= arguments().size(), + "Expected at most ", + arguments().size(), + " argument(s) for operator '", + name(), + "', but received ", + inputs.size(), + " argument(s). Declaration: ", + *this); + + size_t consumed_kwargs = 0; + for (const auto pos : c10::irange(arguments().size())) { + const auto& argument = arguments()[pos]; + if (pos < inputs.size()) { + checkArg(inputs[pos], argument, pos); + continue; + } + auto it = kwargs.find(argument.name()); + if (it != kwargs.end()) { + checkArg(it->second, argument, std::nullopt); + inputs.push_back(it->second); + consumed_kwargs++; + continue; + } + if (argument.default_value()) { + inputs.push_back(*argument.default_value()); + continue; + } + TORCH_CHECK(false, + name(), + "() is missing value for argument '", + argument.name(), + "'. 
Declaration: ", + *this); + } + if (consumed_kwargs != kwargs.size()) { + std::vector names; + names.reserve(kwargs.size()); + for(const auto& k : kwargs) { + names.emplace_back(k.first); + } + TORCH_CHECK(false, findErrorInKwargs(names)); + } +} + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/functional.h b/phivenv/Lib/site-packages/torch/include/ATen/core/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..df1c209b55f33c4f2d1bcb3c6bd933031c8e3f7d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/functional.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +namespace c10 { + +// The passed in function must take T by value (T), or by +// const reference (const T&); taking T by non-const reference +// will result in an error like: +// +// error: no type named 'type' in 'class std::invoke_result' +// +// No explicit template parameters are required. + +// Overload for explicit function and ArrayRef +template +inline auto fmap(const T& inputs, const F& fn) -> std::vector { + std::vector r; + r.reserve(inputs.size()); + for(const auto & input : inputs) + r.push_back(fn(input)); + return r; +} + +// C++ forbids taking an address of a constructor, so here's a workaround... +// Overload for constructor (R) application +template +inline std::vector fmap(const T& inputs) { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) + r.push_back(R(input)); + return r; +} + +template +inline std::vector filter(at::ArrayRef inputs, const F& fn) { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) { + if (fn(input)) { + r.push_back(input); + } + } + return r; +} + +template +inline std::vector filter(const std::vector& inputs, const F& fn) { + return filter(static_cast>(inputs), fn); +} + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/grad_mode.h b/phivenv/Lib/site-packages/torch/include/ATen/core/grad_mode.h new file mode 100644 index 0000000000000000000000000000000000000000..5e7dc5b0ad1ca9ca11f325cb6c5985ffa9815efc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/grad_mode.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +namespace at { + using GradMode = c10::GradMode; + using AutoGradMode = c10::AutoGradMode; + using NoGradGuard = c10::NoGradGuard; +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/interned_strings.h b/phivenv/Lib/site-packages/torch/include/ATen/core/interned_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..5e18c518ab5f64687360c914eff8e205b698511b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/interned_strings.h @@ -0,0 +1,355 @@ +#pragma once + +#include + +#include +#include + +namespace c10 { + +#define FORALL_NS_SYMBOLS(_) \ + _(namespaces, prim) \ + _(namespaces, prims) \ + _(namespaces, nvprims) \ + _(namespaces, aten) \ + _(namespaces, cuda) \ + _(namespaces, onnx) \ + _(namespaces, attr) \ + _(namespaces, scope) \ + _(namespaces, user) \ + _(namespaces, _caffe2) \ + _(namespaces, dimname) \ + _(namespaces, namespaces) \ + _(prim, Assign) \ + _(prim, BroadcastingChunk) \ + _(prim, BroadcastSizes) \ + _(prim, ReductionSizes) \ + _(prim, Constant) \ + _(prim, ChunkSizes) \ + _(prim, ConstantMKLDNNTensor) \ + _(prim, BroadcastMKLDNNTensors) \ + _(prim, MKLDNNGroup) \ + _(prim, MKLDNNHardSwish) \ + _(prim, MKLDNNHardSigmoid) \ + _(prim, MKLDNNHardTanh) \ + _(prim, MKLDNNClamp) \ + _(prim, StaticRuntimeCopyOuts) \ + 
_(prim, Drop) \ + _(prim, Eval) \ + _(prim, Expand) /* onnx */ \ + _(prim, FusionGroup) \ + _(prim, CudaFusionGroup) \ + _(prim, CudaFusionGuard) \ + _(prim, oneDNNFusionGroup) \ + _(prim, oneDNNFusionGuard) \ + _(prim, FunctionalGraph) \ + _(prim, add_optional) \ + _(prim, view_copy) \ + _(prim, permute_copy) \ + _(prim, reshape_copy) \ + _(prim, squeeze_copy) \ + _(prim, t_copy) \ + _(prim, transpose_copy) \ + _(prim, unsqueeze_copy) \ + _(prim, flatten_copy) \ + _(prim, expand_copy) \ + _(prim, expand_as_copy) \ + _(prim, DifferentiableGraph) \ + _(prim, TensorExprGroup) \ + _(prim, TensorExprDynamicGroup) \ + _(prim, StaticSubgraph) \ + _(prim, If) \ + _(prim, Jump) /* debug */ \ + _(prim, JumpNZ) /* debug */ \ + _(prim, JumpZ) /* debug */ \ + _(prim, Load) \ + _(prim, Loop) \ + _(prim, Param) \ + _(prim, PackPadded) /* onnx */ \ + _(prim, PadPacked) /* onnx */ \ + _(prim, Placeholder) /* debug */ \ + _(prim, Print) \ + _(prim, EmptyListLiteral) \ + _(prim, LegacyTypedConstructor) \ + _(prim, PythonOp) \ + _(prim, IgnoredPythonOp) \ + _(prim, Reverse) \ + _(prim, Return) \ + _(prim, ReturnStmt) \ + _(prim, BreakStmt) \ + _(prim, ContinueStmt) \ + _(prim, ComprehensionScope) \ + _(prim, Store) \ + _(prim, AutogradZero) \ + _(prim, AutogradAnyNonZero) \ + _(prim, AutogradAllNonZero) \ + _(prim, AutogradAllZero) \ + _(prim, Starred) \ + _(prim, TupleConstruct) \ + _(prim, TupleUnpack) \ + _(prim, TupleIndex) \ + _(prim, TupleSlice) \ + _(prim, ListConstruct) \ + _(prim, ListUnpack) \ + _(prim, DictConstruct) \ + _(prim, ModuleContainerIndex) \ + _(prim, EnumName) \ + _(prim, EnumValue) \ + _(prim, StringIndex) \ + _(prim, NumToTensor) \ + _(prim, Uninitialized) \ + _(prim, VarConcat) \ + _(prim, VarStack) \ + _(prim, With) \ + _(prim, Enter) \ + _(prim, Exit) \ + _(prim, IfThenElse) \ + _(aten, Bool) \ + _(aten, Int) \ + _(aten, FloatImplicit) \ + _(aten, ComplexImplicit) \ + _(aten, IntImplicit) \ + _(aten, ScalarImplicit) \ + _(aten, Float) \ + _(aten, Complex) \ + _(aten, str) \ + _(aten, Delete) \ + _(prim, device) \ + _(prim, dtype) \ + _(prim, layout) \ + _(prim, id) \ + _(prim, requires_grad) \ + _(prim, MakeTestTensor) /* test */ \ + _(prim, AutogradAdd) \ + _(prim, GradOf) \ + _(aten, grad) \ + _(aten, backward) \ + _(prim, Guard) \ + _(prim, BailOut) \ + _(prim, TypeCheck) \ + _(prim, RequiresGradCheck) \ + _(prim, FallbackGraph) \ + _(prim, FusedConcat) \ + _(prim, ConstantChunk) \ + _(prim, MMTreeReduce) \ + _(prim, MMBatchSide) \ + _(prim, list) \ + _(prim, dict) \ + _(prim, min) \ + _(prim, max) \ + _(prim, abs) \ + _(aten, divmod) \ + _(prim, zip) \ + _(prim, enumerate) \ + _(prim, range) \ + _(prim, rangelist) \ + _(prim, isinstance) \ + _(prim, tolist) \ + _(prim, unchecked_cast) \ + _(aten, _grad_sum_to_size) \ + _(aten, _size_if_not_equal) \ + _(aten, _ncf_unsqueeze) \ + _(aten, warn) \ + _(aten, sorted) \ + _(aten, floordiv) \ + _(aten, __range_length) \ + _(aten, __derive_index) \ + _(aten, __round_to_zero_floordiv) \ + _(aten, is_scripting) \ + _(aten, _unwrap_optional) \ + _(prim, fork) \ + _(prim, awaitable) \ + _(prim, forkClosure) \ + _(prim, awaitableClosure) \ + _(prim, awaitable_nowait) \ + _(prim, awaitable_wait) \ + _(prim, RaiseException) \ + _(prim, Closure) \ + _(prim, CreateObject) \ + _(prim, SetAttr) \ + _(prim, GetAttr) \ + _(prim, HasAttr) \ + _(prim, profile) \ + _(prim, profile_ivalue) \ + _(prim, AddStatValue) \ + _(prim, TimePoint) \ + _(prim, CallFunction) \ + _(prim, CallMethod) \ + _(prim, LoopContinuation) \ + _(prim, annotate) \ + _(prim, 
TracedModuleForward) \ + _(prim, TracedFork) \ + _(prim, TracedAttr) \ + _(prim, rpc_async) \ + _(prim, rpc_sync) \ + _(prim, rpc_remote) \ + _(prim, is_cuda) \ + _(aten, append) \ + _(aten, as_tensor) \ + _(aten, adaptive_avg_pool2d_backward) \ + _(aten, dim) \ + _(aten, format) \ + _(aten, percentFormat) \ + _(aten, __not__) \ + _(aten, __is__) \ + _(aten, __isnot__) \ + _(aten, _ger) \ + _(aten, __getitem__) \ + _(aten, _set_item) \ + _(aten, manual_seed) \ + _(aten, device) \ + _(aten, hash) \ + _(aten, len) \ + _(aten, list) \ + _(aten, dict) \ + _(aten, wait) \ + _(aten, save) \ + _(aten, keys) \ + _(aten, ord) \ + _(aten, chr) \ + _(aten, hex) \ + _(aten, oct) \ + _(aten, clear) \ + _(aten, setdefault) \ + _(aten, bin) \ + _(aten, pop) \ + _(aten, insert) \ + _(aten, tensor) \ + _(prim, unchecked_unwrap_optional) \ + _(aten, __contains__) \ + _(prim, BailoutTemplate) \ + _(prim, grad) \ + _(cuda, _set_device) \ + _(cuda, set_stream) \ + _(cuda, _current_device) \ + _(cuda, synchronize) \ + _(aten, has_torch_function) \ + _(aten, is_autocast_enabled) \ + _(aten, is_autocast_cpu_enabled) \ + _(aten, is_autocast_xla_enabled) \ + _(aten, get_autocast_dtype) \ + _(aten, is_autocast_mps_enabled) \ + FORALL_ATEN_BASE_SYMBOLS(_) \ + _(onnx, Add) \ + _(onnx, Concat) \ + _(onnx, Constant) \ + _(onnx, ConstantFill) \ + _(onnx, Div) \ + _(onnx, GRU) \ + _(onnx, Gather) \ + _(onnx, Gemm) \ + _(onnx, LSTM) \ + _(onnx, MatMul) \ + _(onnx, Min) \ + _(onnx, Max) \ + _(onnx, Mul) \ + _(onnx, Pow) \ + _(onnx, RNN) \ + _(onnx, Shape) \ + _(onnx, Size) \ + _(onnx, Slice) \ + _(onnx, Softmax) \ + _(onnx, Squeeze) \ + _(onnx, Sub) \ + _(onnx, Transpose) \ + _(onnx, Unsqueeze) \ + _(onnx, Loop) \ + _(onnx, If) \ + _(onnx, Reshape) \ + _(onnx, Expand) \ + _(onnx, Equal) \ + _(onnx, Greater) \ + _(onnx, GreaterOrEqual) \ + _(onnx, Less) \ + _(onnx, LessOrEqual) \ + _(onnx, Not) \ + _(aten, ATen) \ + _(onnx, Split) \ + _(onnx, ConstantOfShape) \ + _(onnx, Cast) \ + _(onnx, Mod) \ + _(onnx, Sqrt) \ + _(onnx, SplitToSequence) \ + _(onnx, SequenceAt) \ + _(onnx, SequenceConstruct) \ + _(onnx, SequenceEmpty) \ + _(onnx, SequenceInsert) \ + _(onnx, SequenceErase) \ + _(onnx, ConcatFromSequence) \ + _(onnx, Identity) \ + _(onnx, SoftmaxCrossEntropyLoss) \ + _(onnx, NegativeLogLikelihoodLoss) \ + _(onnx, LogSoftmax) \ + _(onnx, ReduceL1) \ + _(onnx, ReduceL2) \ + _(onnx, Conv) \ + _(onnx, BatchNormalization) \ + _(onnx, ReduceMean) \ + _(onnx, ReduceProd) \ + _(onnx, Relu) \ + _(onnx, Neg) \ + _(onnx, NonZero) \ + _(onnx, Range) \ + _(onnx, Tile) \ + _(onnx, Where) \ + _(onnx, Optional) \ + _(onnx, OptionalGetElement) \ + _(onnx, OptionalHasElement) \ + FORALL_ATTR_BASE_SYMBOLS(_) \ + _(attr, Subgraph) \ + _(attr, ReverseSubgraph) \ + _(attr, f_real_outputs) \ + _(attr, df_input_vjps) \ + _(attr, df_input_captured_inputs) \ + _(attr, df_input_captured_outputs) \ + _(attr, df_output_vjps) \ + _(attr, axes) \ + _(attr, symbolic_shape_inputs) \ + _(attr, allow_stack_outputs) \ + _(attr, striding_inputs_desc) \ + _(attr, striding_outputs_desc) \ + _(attr, broadcast) \ + _(attr, direction) \ + _(attr, ends) \ + _(attr, inplace) \ + _(attr, input_as_shape) \ + _(attr, is_zero) \ + _(attr, num_none) \ + _(attr, num_present) \ + _(attr, perm) \ + _(attr, starts) \ + _(attr, profiled_type) \ + _(attr, transA) \ + _(attr, transB) \ + _(attr, name) \ + _(attr, module) \ + _(attr, beg) \ + _(attr, idx) \ + _(attr, split) \ + _(attr, slot) \ + _(attr, kinds) \ + _(attr, types) \ + _(attr, scope) \ + _(attr, keepdims) \ + _(attr, 
cache_id) \ + _(attr, new_axis) \ + _(attr, warn_id) \ + _(attr, output_layouts) \ + _(attr, allowzero) \ + _(attr, seen_none) \ + _(attr, overload_name) \ + _(attr, node_stack_idx) + +enum class _keys : unique_t { + #define DEFINE_KEY(ns, s) ns##_##s, + FORALL_NS_SYMBOLS(DEFINE_KEY) + #undef DEFINE_KEY + num_symbols +}; + +#define DEFINE_SYMBOL(ns, s) \ + namespace ns { constexpr Symbol s(static_cast(_keys::ns##_##s)); } +FORALL_NS_SYMBOLS(DEFINE_SYMBOL) +#undef DEFINE_SYMBOL + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/interned_strings_class.h b/phivenv/Lib/site-packages/torch/include/ATen/core/interned_strings_class.h new file mode 100644 index 0000000000000000000000000000000000000000..73bfa71d9c27fdbe5e1459e2a93f6be40cf65a86 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/interned_strings_class.h @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include +#include + +namespace c10 { + +struct TORCH_API InternedStrings { + InternedStrings(); + Symbol symbol(const std::string& s); + std::pair string(Symbol sym); + Symbol ns(Symbol sym); + + private: + // prereq - holding mutex_ + Symbol _symbol(const std::string& s); + std::pair customString(Symbol sym); + std::unordered_map string_to_sym_; + + struct SymbolInfo { + Symbol ns; + std::string qual_name; + std::string unqual_name; + }; + std::vector sym_to_info_; + + std::mutex mutex_; +}; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue.h new file mode 100644 index 0000000000000000000000000000000000000000..6794e863bd3a1918c1f8e7d212de0c2f77da5450 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue.h @@ -0,0 +1,1589 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {}; +namespace jit { +using ::torch::CustomClassHolder; +struct Function; +struct CompilationUnit; +struct Module; +} // namespace jit +} // namespace torch +namespace c10 { +template +class Dict; +template +class List; +template +class IListRef; +struct IValue; +struct ClassType; +struct Type; +class RRefInterface; + +struct ClassType; +using ClassTypePtr = std::shared_ptr; + +TORCH_API bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs); + +TORCH_API torch::jit::Function* checkObjectSortSchema( + const c10::ClassTypePtr& t, + std::stringstream& why_not); + +// A comparator that checks ordering of two IValues of same type. +typedef std::function IValueComparator; + +TORCH_API IValueComparator getLessThanComparator(const IValue& v); +TORCH_API IValueComparator getGreaterThanComparator(const IValue& v); + +namespace ivalue { +struct Tuple; +struct Future; +struct Await; +struct ConstantString; +struct GenericDict; +struct Object; +struct PyObjectHolder; +struct EnumHolder; +// We need a ComplexHolder because currently the payloads in the Union +// only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big +// to fit in the IValue directly, we indirect complex numbers through an +// intrusive pointer to ComplexHolder (which contains a c10::complex). 
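As the comment above explains, a 128-bit complex value cannot fit in the 64-bit IValue payload, so it is boxed behind an intrusive pointer to ComplexHolder. A minimal round-trip sketch using the complex constructor and toComplexDouble() declared later in this header:

#include <ATen/core/ivalue.h>

void complex_demo() {
  c10::complex<double> z{1.0, 2.0};
  c10::IValue iv(z);                               // boxed via ivalue::ComplexHolder
  c10::complex<double> back = iv.toComplexDouble(); // unboxes the held value
  (void)back;
}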
+struct ComplexHolder : c10::intrusive_ptr_target { + public: + template + ComplexHolder(c10::complex c) { + val = convert>(c); + } + ComplexHolder() = default; + c10::complex val; +}; + +// Similar to ComplexHolder, for StreamData3 +struct StreamData3Holder : c10::intrusive_ptr_target { + public: + StreamData3Holder(struct c10::StreamData3 d) : val(d) {} + StreamData3Holder() = delete; + struct c10::StreamData3 val; +}; + +} // namespace ivalue + +// This is an owning wrapper for a std::optional> +// that can be implicitly converted to a (non-owning) std::optional>. +// Its purpose is to be used in generated code to keep the vector alive +// either until the end of a statement (as a temporary), or as a saved arg +// in autograd. +template +struct OptionalArray { + std::optional> list; + + OptionalArray() = default; + OptionalArray(std::vector val) : list(std::move(val)) {} + + // Used when saving an argument for the backwards pass. + OptionalArray& operator=(std::optional> ref) { + if (ref) { + list = std::vector(ref->begin(), ref->end()); + } else { + list = std::nullopt; + } + return *this; + } + + // Used when saving an argument for the backwards pass. + OptionalArray& operator=(c10::OptionalArrayRef ref) { + if (ref) { + list = std::vector(ref->begin(), ref->end()); + } else { + list = std::nullopt; + } + return *this; + } + + operator std::optional>() { + if (!list) { + return std::nullopt; + } + return *list; + } + + operator c10::OptionalArrayRef() { + if (!list) { + return std::nullopt; + } + return *list; + } +}; + +// Capsule is an internal implementation detail of custom C++ classes. We +// define it as an owning wrapper for +// c10::intrusive_ptr This wrapper is here to serve as +// an abstraction of the type erased custom class object pointer. It also allow +// pybind11 to treat this as a standalone class to register as a separate type +// caster, instead of a custom pointer holder which the pointer holder type +// caster try to "unwrap" it automatically. +struct Capsule { + c10::intrusive_ptr obj_ptr; + explicit Capsule(c10::intrusive_ptr ptr) + : obj_ptr(std::move(ptr)) {} +}; + +// IValue is the generic tagged union used by the interpreter to hold +// all value types. +// It is a 16-byte object with an 8-byte payload and an 8-byte tag. +// The tag is currently 4 bytes to determine the type, and 1 byte +// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs +// retain/release calls. + +#define TORCH_FORALL_TAGS(_) \ + _(None) \ + _(Tensor) \ + _(Storage) \ + _(Double) \ + _(ComplexDouble) \ + _(Int) \ + _(SymInt) \ + _(SymFloat) \ + _(SymBool) \ + _(Bool) \ + _(Tuple) \ + _(String) \ + _(Blob) \ + _(GenericList) \ + _(GenericDict) \ + _(Future) \ + _(Await) \ + _(Device) \ + _(Stream) \ + _(Object) \ + _(PyObject) \ + _(Uninitialized) \ + _(Capsule) \ + _(RRef) \ + _(Quantizer) \ + _(Generator) \ + _(Enum) + +// [doxygen private] +// These methods are not actually private but we don't want to document them, so +// they are marked `@private`, which hides them on the doxygen documentation for +// this page. + +/// IValue (Interpreter Value) is a tagged union over the types +/// supported by the TorchScript interpreter. IValues contain their +/// values as an `IValue::Payload`, which holds primitive types +/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values, +/// and all other types as a `c10::intrusive_ptr`. 
In order to +/// optimize performance of the destructor and related operations by +/// making the `Tensor` and `c10::intrusive_ptr` paths generate the +/// same code, we represent a null `c10::intrusive_ptr` as +/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`. +/// +/// IValues are used as inputs to and outputs from the TorchScript interpreter. +/// To retrieve the value contained within an IValue, use the `.toX()` methods, +/// where `X` is the type you are trying to get. Note that neither the `.toX()` +/// methods nor the templated `.to` functions do any kind of casting, they +/// only unwrap the contained value. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// // Make the IValue +/// torch::IValue my_ivalue(26); +/// std::cout << my_ivalue << "\n"; +/// +/// // Unwrap the IValue +/// int64_t my_int = my_ivalue.toInt(); +/// std::cout << my_int << "\n"; +/// +/// // This will throw an error! +/// // `my_ivalue` is tagged as an int and cannot be used as another type +/// torch::Tensor my_tensor = my_ivalue.toTensor(); +/// \endrst +struct TORCH_API IValue final { + IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag) { + if (isIntrusivePtr() && + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); + } + } + + IValue(IValue&& rhs) noexcept : tag(rhs.tag) { + moveFrom(std::move(rhs)); + } + + /// @private [doxygen private] + ~IValue() { + destroy(); + } + + C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept { + if (&rhs == this) { + return *this; + } + + destroy(); + moveFrom(std::move(rhs)); + return *this; + } + + IValue& operator=(IValue const& rhs) & { + *this = IValue(rhs); + return *this; + } + + void dump() const; + + /** + * Equality comparison. The semantics are the same as Python's `==`: + * 1. Numerical types are compared by value. + * 2. Tensors compute element-wise equality, returning a BoolTensor (see: + * `torch.eq()`) + * 3. Strings are compared by value. + * 4. Sequence types (list, tuple) are compared lexicographically by + * comparing their elements. Different sequence types never compare equal. + * 5. Mappings (dict) must have equal (key, value) pairs. + * 6. If not listed above, the default behavior for is to test identity + * equality (e.g. pointer equality). + * + * Why does this return an IValue instead of a bool? Because in PyTorch, + * `tensor1 == tensor2` returns a `BoolTensor`, not a bool. + * + * NOTE: we (like Python) assume that identity equality implies value equality + * for efficiency. + * TODO: need to support customizing equality + */ + IValue equals(const IValue& rhs) const; + /** + * This implements the same semantics as `bool(lhs == rhs)` in Python. which + * is the same as `equals()` except for Tensor types. + */ + TORCH_API friend bool operator==(const IValue& lhs, const IValue& rhs); + TORCH_API friend bool operator!=(const IValue& lhs, const IValue& rhs); + + /** + * Identity comparison. Checks if `this` is the same object as `rhs`. The + * semantics are the same as Python's `is` operator. + * + * NOTE: Like in Python, this operation is poorly defined for primitive types + * like numbers and strings. Prefer to use `==` unless you really want to + * check identity equality. + */ + bool is(const IValue& rhs) const; + + /** + * Hashing for IValues. Returns an IValue-boxed int. + * + * Some notes: + * - Like eager, Tensors are hashed by looking at the pointer. 
This is not + * strictly correct because two value-equal tensors with different tensor + * pointers will hash differently, but we choose to reproduce the eager + * semantics. + * - Hashing is not defined on all built-in IValue types (e.g. list and + * dict), following Python. Calling `hash()` on these types will throw. + */ + IValue hash() const { + return (int64_t)IValue::hash(*this); + } + // This is defined because `c10::hash` dispatches to a function of this + // signature. See the member function `hash()`. + static size_t hash(const IValue& iv); + + /** + * @private [doxygen private] + * [container equality] + * This is an equality implementation that assumes objects with the same + * identity equal themselves, for efficiency reasons. We primarily have this + * for consistency, because Python does the same thing. This actually + * provokes user-visible changes in behavior due to quirks in torch: + * [tensor1] == [tensor1] -> True (because container equality will first + * compare identity) [tensor1] == [tensor1_copy] -> RuntimeError: + * Boolean value of Tensor with more than one value is ambiguous + */ + TORCH_API friend bool _fastEqualsForContainer( + const IValue& lhs, + const IValue& rhs); + + private: + static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) { + if (a.is_sparse()) { + return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b); + } + if (b.is_sparse()) { + return isAliasOf(a, b._values()) || isAliasOf(a, b._indices()); + } + if (a.is_sparse_csr()) { + return isAliasOf(a.values(), b) || isAliasOf(a.crow_indices(), b) || + isAliasOf(a.col_indices(), b); + } + if (b.is_sparse_csr()) { + return isAliasOf(a, b.values()) || isAliasOf(a, b.crow_indices()) || + isAliasOf(a, b.col_indices()); + } + + // Opaque tensors such as the ones constructed by the MKL-DNN backend + // don't have storage so we just compare their TensorImpls. + // TODO: Find way to expose alias info for opaque tensors. + if (!a.has_storage() || !b.has_storage()) { + return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl(); + } + + return a.is_alias_of(b); + } + + template + bool isListOf() const; + + public: + /// @private [doxygen private] + bool isAliasOf(const IValue& rhs) const { + if (this->tag != rhs.tag) { + // Trivially don't alias if the type is different + return false; + } + + // Tensors should be compared based on internal storage + if (this->isTensor()) { + return isAliasOf(this->toTensor(), rhs.toTensor()); + } + + if (!isIntrusivePtr()) { + // Primitive types don't alias anything + return false; + } + + AT_ASSERT(rhs.isIntrusivePtr()); + + // Other types can be compared by their ptr value + return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; + } + + /// @private [doxygen private] + size_t use_count() const noexcept { + if (isTensor()) { + return payload.as_tensor.use_count(); + } + + if (!isIntrusivePtrLegacyBehavior()) { + return 1; + } + + if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { + return 0; + } + return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr); + } + + /// @private [doxygen private] + void swap(IValue& rhs) noexcept { + if (isTensor() && rhs.isTensor()) { + std::swap(payload.as_tensor, rhs.payload.as_tensor); + } else if (isTensor()) { + at::Tensor t = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. 
The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // payload.as_tensor.~Tensor(); + payload.u = rhs.payload.u; + new (&rhs.payload.as_tensor) at::Tensor(std::move(t)); + } else if (rhs.isTensor()) { + rhs.swap(*this); + return; + } else { + std::swap(payload.u, rhs.payload.u); + } + std::swap(tag, rhs.tag); + } + + // Accessors for subtypes are arranged together below + // While some of these accessors could be generated through templates, + // we prefer to write them manually for clarity + + IValue(at::TensorBase t) : tag(Tag::Tensor) { + new (&payload.as_tensor) at::Tensor(std::move(t)); + } + bool isTensor() const { + return Tag::Tensor == tag; + } + + private: + // Outlined error path so that toTensor() can be inlined. + [[noreturn]] void reportToTensorTypeError() const; + + public: + at::Tensor toTensor() &&; + at::Tensor& toTensor() &; + const at::Tensor& toTensor() const&; + at::TensorImpl* unsafeToTensorImpl() const { + TORCH_INTERNAL_ASSERT(isTensor()); + return payload.as_tensor.unsafeGetTensorImpl(); + } + + IValue(at::Storage s) : tag(Tag::Storage) { + payload.u.as_intrusive_ptr = + null_to_undefined_tensor(s.unsafeReleaseStorageImpl()); + } + bool isStorage() const { + return Tag::Storage == tag; + } + c10::Storage toStorage() &&; + c10::Storage toStorage() const&; + + const IValue& toIValue() const { + return *this; + } + IValue& toIValue() { + return *this; + } + + /// @private [doxygen private] + IValue(intrusive_ptr blob) : tag(Tag::Blob) { + // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract + // and store it as a Tensor instead. + payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); + } + + /// @private [doxygen private] + bool isBlob() const { + return Tag::Blob == tag; + } + + /// @private [doxygen private] + c10::intrusive_ptr toBlob() &&; + + /// @private [doxygen private] + c10::intrusive_ptr toBlob() const&; + + // Capsule. No new callsites of these APIs should + // be introduced. + static inline IValue make_capsule( + intrusive_ptr blob); + bool isCapsule() const { + return Tag::Capsule == tag; + } + c10::intrusive_ptr toCapsule() &&; + c10::intrusive_ptr toCapsule() const&; + + // Custom C++ classes + template < + typename T, + std::enable_if_t, int> = 0> + IValue(intrusive_ptr custom_class); + bool isCustomClass() const; + template + c10::intrusive_ptr toCustomClass() &&; + template + c10::intrusive_ptr toCustomClass() const&; + + // Tuple + IValue(c10::intrusive_ptr v); + + template < + typename... Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t> = nullptr> + IValue(const std::tuple& t); + template < + typename... 
Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t> = nullptr> + IValue(std::tuple&& t); + bool isTuple() const { + return Tag::Tuple == tag; + } + c10::intrusive_ptr toTuple() &&; + c10::intrusive_ptr toTuple() const&; + [[nodiscard]] ivalue::Tuple& toTupleRef() const; + + // Double + IValue(double d) : tag(Tag::Double) { + payload.u.as_double = d; + } + bool isDouble() const { + return Tag::Double == tag; + } + double toDouble() const { + if (isDouble()) { + return payload.u.as_double; + } else if (isSymFloat()) { + return toSymFloat().guard_float(__FILE__, __LINE__); + } else { + TORCH_INTERNAL_ASSERT(0, "expected double"); + } + } + + // ComplexDouble + template + IValue(c10::complex c); + bool isComplexDouble() const { + return Tag::ComplexDouble == tag; + } + c10::complex toComplexDouble() const; + + // Future + IValue(c10::intrusive_ptr v); + bool isFuture() const { + return Tag::Future == tag; + } + c10::intrusive_ptr toFuture() &&; + c10::intrusive_ptr toFuture() const&; + + IValue(c10::intrusive_ptr v); + bool isAwait() const { + return Tag::Await == tag; + } + c10::intrusive_ptr toAwait() &&; + c10::intrusive_ptr toAwait() const&; + + // RRef + IValue(c10::intrusive_ptr v); + bool isRRef() const { + return Tag::RRef == tag; + } + c10::intrusive_ptr toRRef() &&; + c10::intrusive_ptr toRRef() const&; + + // Quantizer + IValue(c10::intrusive_ptr v); + bool isQuantizer() const { + return Tag::Quantizer == tag; + } + c10::intrusive_ptr toQuantizer() &&; + c10::intrusive_ptr toQuantizer() const&; + + // Int + IValue(int64_t i) : tag(Tag::Int) { + payload.u.as_int = i; + } + + IValue(const c10::SymInt& i) { + if (auto mi = i.maybe_as_int()) { + tag = Tag::Int; + payload.u.as_int = *mi; + } else { + tag = Tag::SymInt; + payload.u.as_intrusive_ptr = i.toSymNode().release(); + } + } + + bool isSymInt() const { + return Tag::SymInt == tag; + } + + c10::SymInt toSymInt() &&; + c10::SymInt toSymInt() const&; + + IValue(const c10::SymFloat& i) { + if (i.is_symbolic()) { + tag = Tag::SymFloat; + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); + } else { + tag = Tag::Double; + payload.u.as_double = i.as_float_unchecked(); + } + } + + bool isSymFloat() const { + return Tag::SymFloat == tag; + } + + c10::SymFloat toSymFloat() &&; + c10::SymFloat toSymFloat() const&; + + IValue(const c10::SymBool& i) { + if (auto mi = i.maybe_as_bool()) { + tag = Tag::Bool; + payload.u.as_int = *mi; + } else { + tag = Tag::SymBool; + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); + } + } + + bool isSymBool() const { + return Tag::SymBool == tag; + } + + c10::SymBool toSymBool() &&; + c10::SymBool toSymBool() const&; + + // allow you to pass literals (3, 4) without ambiguity + IValue(int32_t i) : IValue(static_cast(i)) {} + + bool isInt() const { + return Tag::Int == tag; + } + + int64_t toInt() const { + if (isInt()) { + return payload.u.as_int; + } else if (isSymInt()) { + return toSymInt().guard_int(__FILE__, __LINE__); + } else { + TORCH_INTERNAL_ASSERT(0, "expected int"); + } + } + + // Bool + IValue(bool b) : tag(Tag::Bool) { +#if defined(__clang__) && defined(__x86_64__) + // Initializing entire payload stops valgrind's from reporting + // "jump or move depends on uninitialised value" in IValue copy constructor + // See https://github.com/pytorch/pytorch/issues/37117 + payload.u.as_int = b; +#else + payload.u.as_bool = b; +#endif + } + bool isBool() const { + return Tag::Bool == tag; + } + bool toBool() const { + if 
(isBool()) { + return payload.u.as_bool; + } else if (isSymBool()) { + return toSymBool().guard_bool(__FILE__, __LINE__); + } else { + TORCH_INTERNAL_ASSERT(0, "expected bool"); + } + } + + // IntList + bool isIntList() const; + bool isSymIntList() const; + c10::List toIntList() &&; + c10::List toIntList() const&; + std::vector toIntVector() const; + c10::List toSymIntList() &&; + c10::List toSymIntList() const&; + std::vector toSymIntVector() const; + at::DimVector toDimVector() const; + + // ConstantString + IValue(c10::intrusive_ptr v); + IValue(std::string v); + IValue(const char* v) : IValue(std::string(v)) {} + IValue(std::string_view v) : IValue(std::string(v)){} + bool isString() const { + return Tag::String == tag; + } + c10::intrusive_ptr toString() &&; + c10::intrusive_ptr toString() const&; + const std::string& toStringRef() const; + std::optional> toOptionalStringRef() + const; + std::string_view toStringView() const; + + // DoubleList + bool isDoubleList() const; + c10::List toDoubleList() &&; + c10::List toDoubleList() const&; + std::vector toDoubleVector() const; + + // ComplexDoubleList + bool isComplexDoubleList() const; + c10::List> toComplexDoubleList() &&; + c10::List> toComplexDoubleList() const&; + std::vector> toComplexDoubleVector() const; + + // BoolList + bool isBoolList() const; + c10::List toBoolList() &&; + c10::List toBoolList() const&; + + // TensorList + bool isTensorList() const; + c10::List toTensorList() &&; + c10::List toTensorList() const&; + std::vector toTensorVector() const; + + // OptionalTensorList + bool isOptionalTensorList() const; + c10::List> toOptionalTensorList() &&; + c10::List> toOptionalTensorList() const&; + std::vector> toOptionalTensorVector() const; + + // GenericList + IValue(c10::List v); + bool isList() const { + return Tag::GenericList == tag; + } + c10::List toList() &&; + c10::List toList() const&; + c10::ArrayRef toListRef() const; + + // Some template constructors of IValue calls another constructor recursively. + // This SFINAEs the called constructor exists. + template + using enable_if_ivalue_constructible = + std::enable_if_t, std::nullptr_t>; + + // The rule for lists is more complicated; the generic constructor is only + // acceptable if your element isn't SymInt. If you do have a SymInt element, + // then you must also, at construction time, check if you can decay the list + // into an int list (this is MANDATORY, as at a use site we may expect + // toIntList to work even if at the call site you had a SymIntArrayRef + // argument). In practice, only SymIntArrayRef is used this way, so we + // didn't bother making it work for the other constructors, we just make sure + // they're not selectable. + template + using enable_if_list_is_ivalue_constructible = std::enable_if_t< + std::is_constructible_v && !std::is_same_v, + std::nullptr_t>; + + template = nullptr> + IValue(c10::List&& v); + template = nullptr> + IValue(const c10::List& v); + template = nullptr> + IValue(at::ArrayRef v); + template = nullptr> + IValue(const std::vector& v); + template = nullptr> + IValue(std::vector&& v); + template + IValue(std::array v); + + // Manual constructors for lists of symints, which decay to int list if + // possible. 
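// Example (editor's illustrative sketch; not part of the original header):
// boxing a std::vector of ints with the list constructors above and
// unboxing it again. The list is stored as a GenericList whose element type
// is tracked separately, which is what isIntList()/toIntList() consult.
//
//   std::vector<int64_t> sizes = {1, 2, 3};
//   IValue boxed(sizes);                            // IValue(const std::vector<T>&)
//   bool is_ints = boxed.isIntList();               // true
//   c10::List<int64_t> as_list = boxed.toIntList();
//   std::vector<int64_t> round_trip = boxed.toIntVector();
//   // A SymInt list with no symbolic entries decays to the same
//   // representation, so toIntList() keeps working at use sites that
//   // expect plain ints.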
To avoid ambiguous overload situations, we template them + // to prevent implicit conversions + template + using enable_if_symint = + std::enable_if_t, std::nullptr_t>; + + template = nullptr> + IValue(at::ArrayRef v); + template = nullptr> + IValue(at::OptionalArrayRef v); + template = nullptr> + IValue(const std::vector& v); + template = nullptr> + IValue(std::vector&& v); + + template + using enable_if_ilist_is_ivalue_constructible = std::enable_if_t< + std::is_constructible_v && + std::is_constructible_v::boxed_type> && + !std::is_same_v, + std::nullptr_t>; + + template = nullptr> + IValue(c10::IListRef v); + + // GenericDict + IValue(c10::Dict v); + bool isGenericDict() const { + return Tag::GenericDict == tag; + } + c10::Dict toGenericDict() &&; + c10::Dict toGenericDict() const&; + + template + IValue(c10::Dict v); + + template + /// \cond + /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN + C10_DEPRECATED_MESSAGE( + "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.") + /// \endcond + IValue(std::unordered_map v); + + template = nullptr> + IValue(std::optional v); + template = nullptr> + IValue(c10::OptionalArrayRef v); + IValue(std::nullopt_t); + + // ClassType + IValue(c10::intrusive_ptr v); + bool isObject() const { + return tag == Tag::Object; + } + c10::intrusive_ptr toObject() &&; + c10::intrusive_ptr toObject() const&; + ivalue::Object& toObjectRef() const; + + torch::jit::Module toModule() const; + bool isModule() const; + + // PyObject + IValue(c10::intrusive_ptr v); + bool isPyObject() const { + return tag == Tag::PyObject; + } + c10::intrusive_ptr toPyObjectHolder() &&; + c10::intrusive_ptr toPyObjectHolder() const&; + PyObject* toPyObject() const; + + // Enum + explicit IValue(c10::intrusive_ptr v); + bool isEnum() const { + return tag == Tag::Enum; + } + c10::intrusive_ptr toEnumHolder() &&; + c10::intrusive_ptr toEnumHolder() const&; + + // None + IValue() = default; + bool isNone() const { + return Tag::None == tag; + } + std::string toNone() const { + AT_ASSERT(isNone()); + return "None"; + } + + static IValue uninitialized() { + auto i = IValue(); + i.tag = Tag::Uninitialized; + return i; + } + + // Scalar, which gets encoded as either an Int, a Double or a ComplexDouble + IValue(const at::Scalar& s) : IValue() { + // NB: do the symbolic versions first, as isFloatingPoint is true + // for both SymFloat and double + if (s.isSymInt()) { + tag = Tag::SymInt; + payload.u.as_intrusive_ptr = s.toSymInt().toSymNode().release(); + } else if (s.isSymFloat()) { + tag = Tag::SymFloat; + payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release(); + } else if (s.isSymBool()) { + tag = Tag::SymBool; + payload.u.as_intrusive_ptr = s.toSymBool().toSymNodeImpl().release(); + } else if (s.isFloatingPoint()) { + tag = Tag::Double; + payload.u.as_double = s.toDouble(); + } else if (s.isComplex()) { + *this = s.toComplexDouble(); + } else if (s.isBoolean()) { + tag = Tag::Bool; + payload.u.as_bool = s.toBool(); + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + s.isIntegral(false), "Unknown type in Scalar"); + tag = Tag::Int; + payload.u.as_int = s.toLong(); + } + } + + bool isScalar() const { + return isDouble() || isInt() || isComplexDouble() || isBool() || + isSymInt() || isSymFloat() || isSymBool(); + } + + at::Scalar toScalar() const { + if (isDouble()) + return toDouble(); + else if (isInt()) + return toInt(); + else if (isComplexDouble()) + return toComplexDouble(); + else if (isBool()) + 
return toBool(); + else if (isSymInt()) + return toSymInt(); + else if (isSymFloat()) + return toSymFloat(); + else if (isSymBool()) + return toSymBool(); + TORCH_CHECK(false, "IValue is not a Scalar"); + } + + // Device + IValue(c10::Device d) : tag(Tag::Device) { + payload.u.as_device.type = d.type(); + payload.u.as_device.index = d.index(); + } + bool isDevice() const { + return Tag::Device == tag; + } + c10::Device toDevice() const { + AT_ASSERT(isDevice()); + return c10::Device(payload.u.as_device.type, payload.u.as_device.index); + } + + // Stream + IValue(c10::Stream s) : tag(Tag::Stream) { + auto v = c10::make_intrusive(s.pack3()); + payload.u.as_intrusive_ptr = v.release(); + } + c10::Stream toStream() &&; + c10::Stream toStream() const&; + bool isStream() const { + return Tag::Stream == tag; + } + + // ScalarType + IValue(ScalarType t) + : IValue(static_cast>(t)) {} + at::ScalarType toScalarType() const { + return static_cast(toInt()); + } + + // Layout + IValue(Layout l) : IValue(static_cast>(l)) {} + at::Layout toLayout() const { + return static_cast(toInt()); + } + + // MemoryFormat + IValue(MemoryFormat m) + : IValue(static_cast>(m)) {} + at::MemoryFormat toMemoryFormat() const { + return static_cast(toInt()); + } + + // QScheme + IValue(at::QScheme qscheme) : tag(Tag::Int) { + payload.u.as_int = static_cast(qscheme); + } + + at::QScheme toQScheme() const { + return static_cast(toInt()); + } + + // Dimname + IValue(at::Dimname dimname) : IValue(dimname.symbol().toQualString()) {} + + at::Dimname toDimname() const { + return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef())); + } + + // Generator + IValue(at::Generator g) : tag(Tag::Generator) { + payload.u.as_intrusive_ptr = + null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl()); + } + bool isGenerator() const { + return Tag::Generator == tag; + } + at::Generator toGenerator() &&; + at::Generator toGenerator() const&; + + // for debugging + std::string tagKind() const { + switch (tag) { +#define DEFINE_CASE(x) \ + case Tag::x: \ + return #x; + TORCH_FORALL_TAGS(DEFINE_CASE) +#undef DEFINE_CASE + } + return "InvalidTag(" + std::to_string(static_cast(tag)) + ")"; + } + + // generic v.to() implementations + // that can be used in special functions like pop/push + // that use template meta-programming. + // prefer the directly named methods when you can, + // since they are simpler to understand + + // Note: if you get linker errors saying one of these is missing, + // change it to ... && = delete; and you will see better error messages for + // why However, we cannot commit this because some compiler versions barf on + // it. + template + T to() &&; + template + typename c10::detail::ivalue_to_const_ref_overload_return::type to() + const&; + + // ToOptional: convert a IValue to the Optional obj that accepts both T and + // None + template + std::optional toOptional(); + template + std::optional toOptional() const; + + /// @private [doxygen private] + /// this is a shallow comparison of two IValues to test the object identity + bool isSameIdentity(const IValue& rhs) const; + + // Computes the "official" string representation of an IValue. This produces a + // TorchScript expression that can be used to recreate an IValue with the same + // value (e.g. when we are printing constants in the serializer). + // + // Callers can use `customFormatter` to override how `repr()` prints out an + // IValue. 
This is useful if you have some other environment where you can + // look up values, and you want to print a reference to that environment (like + // the serializer's constant table). + // + // repr() is not necessarily defined on all objects! + std::ostream& repr( + std::ostream& stream, + std::function customFormatter) + const; + + // Computes an "informal" string representation of an IValue. This should be + // used for debugging, or servicing `print()`-like functions. + // This is different from `repr()` in that there is no expectation that we can + // exactly reconstruct an IValue from the output; feel free to use a + // concise/pretty form + TORCH_API friend std::ostream& operator<<(std::ostream& out, const IValue& v); + + bool isPtrType() const { + if (isTensor()) { + return payload.as_tensor.defined(); + } + return isIntrusivePtrLegacyBehavior(); + } + + /// @private [doxygen private] + const void* internalToPointer() const { + TORCH_INTERNAL_ASSERT( + isPtrType(), "Can only call internalToPointer() for pointer types"); + if (isTensor()) { + return payload.as_tensor.unsafeGetTensorImpl(); + } else { + return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton() + ? payload.u.as_intrusive_ptr + : nullptr; + } + } + + template + TypePtr type() const; + + // Detect aliased tensors. + struct HashAliasedIValue { + size_t hashTensor(const at::Tensor& ten) const { + if (ten.is_sparse()) { + // COO sparse tensors have a "values" tensor and an "indices" tensor + // so this will detect overlap of sparse tensors that share a values + // tensor, but not sparse tensors that share an indices tensor. + return hashTensor(ten._values()); + } else if (ten.is_sparse_csr()) { + // COO sparse tensors have a "values" tensor and an "indices" tensor + // so this will detect overlap of sparse tensors that share a values + // tensor, but not sparse tensors that share an indices tensor. + return hashTensor(ten.values()); + } else if (!ten.has_storage()) { + // Opaque tensors such as the ones constructed by the MKL-DNN backend + // don't have storage so we just use their TensorImpls. + // TODO: Find way to expose alias info for opaque tensors. + return reinterpret_cast(ten.unsafeGetTensorImpl()); + } else { + return reinterpret_cast(ten.storage().unsafeGetStorageImpl()); + } + } + size_t operator()(const IValue& val) const { + if (val.isTensor()) { + return hashTensor(val.toTensor()); + } + // If it is not a Tensor, then two mutable IValues alias each other only + // if they are the same pointer. + return val.payload.u.as_int; + } + }; + + struct CompAliasedIValues { + bool operator()(const IValue& lhs, const IValue& rhs) const { + return lhs.isAliasOf(rhs); + } + }; + + using HashAliasedIValues = + std::unordered_set; + using HashAliasedIValueMap = + std::unordered_map; + + struct HashIdentityIValue { + size_t operator()(const IValue& val) const { + return val.payload.u.as_int; + } + }; + + struct CompIdentityIValues { + bool operator()(const IValue& lhs, const IValue& rhs) const { + return lhs.is(rhs); + } + }; + + using HashIdentityIValues = + std::unordered_set; + using HashIdentityIValueMap = + std::unordered_map; + + // Chechs if this and rhs has a subvalues in common. + // [t1,t2] and [t2, t3] returns true. + bool overlaps(const IValue& rhs) const; + + // Inserts all subvalues of this in subValues. + void getSubValues(HashAliasedIValues& subValues) const; + + // Apply visitor to every subvalue. + // TODO: There are several places that recurse over IValue. This is fragile. 
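// Example (editor's illustrative sketch; not part of the original header,
// and assuming the full ATen API, e.g. at::ones, is available): overlaps()
// and getSubValues() on two lists that share a tensor, mirroring the
// [t1, t2] vs [t2, t3] case described above.
//
//   at::Tensor t = at::ones({2, 2});
//   IValue a(c10::List<at::Tensor>({t}));
//   IValue b(c10::List<at::Tensor>({t}));
//   bool shared = a.overlaps(b);                    // true: both reach `t`
//   IValue::HashAliasedIValues subs;
//   a.getSubValues(subs);                           // the list plus the tensor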
+ // This visitor should be used to recurse over ivalues. + void visit(const std::function& visitor) const; + IValue deepcopy(std::optional device = std::nullopt) const; + IValue deepcopy( + HashIdentityIValueMap& memo, + std::optional device = std::nullopt) const; + + private: + static c10::intrusive_ptr_target* null_to_undefined_tensor( + c10::intrusive_ptr_target* p) { + return p ? p + : static_cast( + c10::UndefinedTensorImpl::singleton()); + } + + static bool ptrEqual(const IValue& lhs, const IValue& rhs); + // NOTE: IValue tags are intentionally private. In the future we may encode + // this value different (e.g. using NaN boxing), and this would make it more + // costly to determine the tag for all types vs just determining if something + // is a particular type. Instead we want clients to use the `isX` methods when + // possible. If for performance reasons you really, absolutely, must have a jump + // table, then we can revisit this. + enum class Tag : uint32_t { +#define DEFINE_TAG(x) x, + TORCH_FORALL_TAGS(DEFINE_TAG) +#undef DEFINE_TAG + }; + +#define COUNT_TAG(x) 1 + + static constexpr auto kNumTags = TORCH_FORALL_TAGS(COUNT_TAG) 0; +#undef COUNT_TAG + + template < + class T, + class NullType = c10::detail::intrusive_target_default_null_type> + c10::intrusive_ptr moveToIntrusivePtr(); + template < + typename T, + class NullType = c10::detail::intrusive_target_default_null_type> + c10::intrusive_ptr toIntrusivePtr() const; + + void destroy() { + // We carefully construct this call to both 1) avoid UB by using + // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable + // the compiler to generate the same code for each case. It is + // surprisingly difficult to get this right. + if (isTensor() || isIntrusivePtr()) { + c10::intrusive_ptr_target* p = isTensor() + ? payload.as_tensor.unsafeGetTensorImpl() + : payload.u.as_intrusive_ptr; + c10::intrusive_ptr:: + reclaim(p); + // No need to make this destructor call! + // payload.as_tensor.~Tensor(); + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept { + if (rhs.isTensor()) { + new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // rhs.payload.as_tensor.~Tensor(); + } else { + payload.u = rhs.payload.u; + } + tag = rhs.tag; + rhs.clearToNone(); + } + + void clearToNone() noexcept { + payload.u.as_int = 0; + tag = Tag::None; + } + + private: + // This is the source of truth for isIntrusivePtr; edit results here + // as needed and isIntrusivePtr will pick them up. 
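// Editor's note (illustrative sketch; not part of the original header; the
// EX_* macros are hypothetical): kNumTags above is computed with a
// preprocessor counting trick. TORCH_FORALL_TAGS(COUNT_TAG) expands every
// tag name to "1 +", and the trailing 0 terminates the sum, yielding the
// number of tags. A three-tag analogue:
//
//   #define EX_FORALL_TAGS(_) _(None) _(Tensor) _(Int)
//   #define EX_COUNT_TAG(x) 1 +
//   static_assert(EX_FORALL_TAGS(EX_COUNT_TAG) 0 == 3, "1 + 1 + 1 + 0 == 3");
//   #undef EX_COUNT_TAG
//   #undef EX_FORALL_TAGS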
+ // NOLINTBEGIN(bugprone-branch-clone) + static constexpr bool isIntrusivePtrConstexpr(Tag tag) { + switch (tag) { + case Tag::None: + return false; + case Tag::Tensor: + return false; + case Tag::Storage: + return true; + case Tag::Generator: + return true; + case Tag::Double: + return false; + case Tag::ComplexDouble: + return true; + case Tag::Int: + return false; + case Tag::SymInt: + return true; + case Tag::SymFloat: + return true; + case Tag::SymBool: + return true; + case Tag::Bool: + return false; + case Tag::Tuple: + return true; + case Tag::String: + return true; + case Tag::Blob: + return true; + case Tag::GenericList: + return true; + case Tag::GenericDict: + return true; + case Tag::Future: + return true; + case Tag::Await: + return true; + case Tag::Device: + return false; + case Tag::Stream: + return true; + case Tag::Object: + return true; + case Tag::PyObject: + return true; + case Tag::Uninitialized: + return false; + case Tag::Capsule: + return true; + case Tag::RRef: + return true; + case Tag::Quantizer: + return true; + case Tag::Enum: + return true; + } + return false; + } + // NOLINTEND(bugprone-branch-clone) + + public: + // Don't edit this just to add results for new tags; edit + // isIntrusivePtrConstexpr above. + bool isIntrusivePtr() const { + // Implementation NOTE: the switch in isIntrusivePtrConstexpr + // above is the previous production implementation of this + // function. We observed that, at least on x86_64, the generated + // instruction sequence was a similar bit vector test to what we + // have manually implemented below, except that there was an extra + // "bounds check" branch confirming, essentially, that `tag < + // kNumTags` and providing a consistent result in that case. We + // don't care about the result if tag is out of bounds, so we'd + // like to eliminate that comparison and branch; manually + // implementing this function as a bit test is the simplest way I + // could find to accomplish that elimination. + static constexpr uint32_t kTruthTableBitVector = +#define TRUTH_TABLE_ENTRY(tag) \ + (uint32_t(isIntrusivePtrConstexpr(Tag::tag)) << uint32_t(Tag::tag)) | + TORCH_FORALL_TAGS(TRUTH_TABLE_ENTRY) +#undef TRUTH_TABLE_ENTRY + 0; + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + static_cast(tag) < kNumTags, + "unexpected tag ", + static_cast(tag)); + return kTruthTableBitVector & (1 << (uint32_t(tag) % 32)); + } + + // Storage and Generator were treated specially when + // is_intrusive_ptr was stored as explicit state. This getter + // preserves the old behavior for use with WeakIValue for now. + bool isIntrusivePtrLegacyBehavior() const { + if (tag == Tag::Storage || tag == Tag::Generator) { + return payload.u.as_intrusive_ptr != + c10::UndefinedTensorImpl::singleton(); + } else { + return isIntrusivePtr(); + } + } + + union Payload { + // [TriviallyCopyablePayload] + // We use a nested union here so that we can make the copy easy + // and efficient in the non-tensor (i.e., trivially copyable) + // case. Specifically, we do not have to do a switch-on-tag to + // figure out which union member to assign; we can just use + // TriviallyCopyablePayload::operator=. + union TriviallyCopyablePayload { + TriviallyCopyablePayload() : as_int(0) {} + int64_t as_int; + double as_double; + bool as_bool; + // Invariant: never nullptr; null state is represented as + // c10::UndefinedTensorImpl::singleton() for consistency of + // representation with Tensor. 
+ c10::intrusive_ptr_target* as_intrusive_ptr; + struct { + c10::DeviceType type; + DeviceIndex index; + } as_device; + } u; + static_assert(std::is_trivially_copyable_v); + at::Tensor as_tensor; + Payload() : u() {} + Payload(const Payload&) = delete; + Payload(Payload&&) = delete; + Payload& operator=(const Payload&) = delete; + Payload& operator=(Payload&&) = delete; + // NOLINTNEXTLINE(modernize-use-equals-default) + ~Payload() {} + }; + + IValue(const Payload& p, Tag t) : tag(t) { + if (isTensor()) { + new (&payload.as_tensor) at::Tensor(p.as_tensor); + } else { + payload.u = p.u; + } + } + + template + struct TagType {}; + + friend MaybeOwnedTraits; + + Payload payload; + Tag tag{IValue::Tag::None}; + friend struct WeakIValue; +}; + +struct TORCH_API WeakIValue final { + WeakIValue() = default; + + WeakIValue(const WeakIValue& rhs) + : payload(rhs.payload), + tag(rhs.tag), + is_intrusive_ptr(rhs.is_intrusive_ptr) { + if (is_intrusive_ptr && + payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + } + } + WeakIValue(const IValue& rhs) + : tag(rhs.tag), is_intrusive_ptr(rhs.isIntrusivePtrLegacyBehavior()) { + if (rhs.isTensor()) { + payload.as_intrusive_ptr = rhs.unsafeToTensorImpl(); + is_intrusive_ptr = true; + } else { + payload = rhs.payload.u; + } + if (is_intrusive_ptr) { + if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + } + } + } + WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() { + swap(rhs); + } + ~WeakIValue() { + if (is_intrusive_ptr && + payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr); + } + } + WeakIValue& operator=(WeakIValue&& rhs) & noexcept { + WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None + return *this; + } + WeakIValue& operator=(WeakIValue const& rhs) & { + WeakIValue(rhs).swap(*this); + return *this; + } + void swap(WeakIValue& rhs) noexcept { + std::swap(payload, rhs.payload); + std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); + std::swap(tag, rhs.tag); + } + + bool isSameIdentity(const WeakIValue& rhs) const { + return payload.as_int == rhs.payload.as_int && tag == rhs.tag && + is_intrusive_ptr == rhs.is_intrusive_ptr; + } + + IValue lock() const { + if (!is_intrusive_ptr) { + IValue::Payload newPayload; + newPayload.u = payload; + return IValue(newPayload, tag); + } + if (IValue::Tag::Tensor == tag) { + auto temp = + c10::weak_intrusive_ptr:: + reclaim(static_cast(payload.as_intrusive_ptr)); + c10::intrusive_ptr ip( + temp.lock()); + temp.release(); + if (!ip) { + return IValue(); + } else { + return IValue(at::Tensor(std::move(ip))); + } + } else { + auto temp = c10::weak_intrusive_ptr::reclaim( + payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? 
nullptr + : payload.as_intrusive_ptr); + IValue::Payload pl; + pl.u.as_intrusive_ptr = temp.lock().release(); + temp.release(); + if (!pl.u.as_intrusive_ptr) { + return IValue(); + } else { + return IValue(pl, tag); + } + } + } + + size_t use_count() const noexcept { + if (!is_intrusive_ptr) { + return 1; + } + auto temp = c10::weak_intrusive_ptr< + c10::intrusive_ptr_target, + c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr); + size_t result = temp.use_count(); + temp.release(); + return result; + } + + size_t weak_use_count() const noexcept { + if (!is_intrusive_ptr) { + return 1; + } + auto temp = c10::weak_intrusive_ptr< + c10::intrusive_ptr_target, + c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr); + size_t result = temp.weak_use_count(); + temp.release(); + return result; + } + size_t hash() const { + return payload.as_int; + } + + private: + using Payload = IValue::Payload::TriviallyCopyablePayload; + Payload payload; + IValue::Tag tag{IValue::Tag::None}; + bool is_intrusive_ptr{false}; +}; + +// An owning pointer to a type. When the type is class type, it requires a pair +// of shared_ptrs to the class type and its owning CU, so that the class type is +// guaranteed to stay alive as long as we hold this object. +struct TORCH_API StrongTypePtr { + StrongTypePtr(std::shared_ptr cu, TypePtr type); + + std::shared_ptr cu_; + TypePtr type_; +}; + +// [Constant Object Weak CompilationUnit Reference] +// A non owning pointer to a type. When a class get inserted as a constant +// into a graph, if we used a strong pointer we would have a circular reference +// from Object -> CompilationUnit and CompilationUnit -> Graph (which owns the +// Constant Object) +struct TORCH_API WeakTypePtr { + WeakTypePtr(std::weak_ptr cu, TypePtr type); + + std::weak_ptr cu_; + TypePtr type_; +}; + +// internal build errors with std::variant :/ +struct WeakOrStrongCompilationUnit { + explicit WeakOrStrongCompilationUnit( + std::shared_ptr shared_cu) + : strong_ptr_(std::move(shared_cu)), weak_ptr_(std::nullopt) {} + + explicit WeakOrStrongCompilationUnit( + std::weak_ptr weak_cu) + : strong_ptr_(std::nullopt), weak_ptr_(std::move(weak_cu)) {} + + std::shared_ptr getStrongRefOrThrow() const { + TORCH_INTERNAL_ASSERT(strong_ptr_.has_value()); + return *strong_ptr_; + } + + std::weak_ptr getWeakRefOrThrow() const { + TORCH_INTERNAL_ASSERT(weak_ptr_.has_value()); + return *weak_ptr_; + } + + bool holdingStrongRef() const { + return strong_ptr_.has_value(); + } + + bool holdingEmptyStrongRef() const { + return strong_ptr_ == nullptr; + } + + std::optional> strong_ptr_; + std::optional> weak_ptr_; +}; + +// An Object will hold a non-owning Compilation Unit reference if it is a +// Constant in the graph and a Owning reference otherwise +struct TORCH_API WeakOrStrongTypePtr { + explicit WeakOrStrongTypePtr(WeakTypePtr weak) + : cu_(WeakOrStrongCompilationUnit(std::move(weak.cu_))), + type_(std::move(weak.type_)) {} + explicit WeakOrStrongTypePtr(StrongTypePtr strong) + : cu_(WeakOrStrongCompilationUnit(std::move(strong.cu_))), + type_(std::move(strong.type_)) {} + explicit WeakOrStrongTypePtr(WeakOrStrongCompilationUnit cu, TypePtr type) + : cu_(std::move(cu)), type_(std::move(type)) {} + WeakTypePtr asWeakTypePtr() const; + + WeakOrStrongCompilationUnit cu_; + TypePtr type_; + + bool holds_strong_ref() const { + return cu_.holdingStrongRef(); + } + + bool holds_empty_strong_ref() const { + return cu_.holdingEmptyStrongRef(); + } +}; + +} // namespace c10 + +#include // IWYU pragma: keep diff 
--git a/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..5e442105363d56ddf8290636e0875441d4cc082d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue_inl.h @@ -0,0 +1,2569 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +struct Function; +struct CompilationUnit; +} // namespace jit +TORCH_API bool isCustomClass(const c10::IValue& v); +} // namespace torch +namespace c10 { +struct IValue; +struct ClassType; +struct TupleType; +struct EnumType; +struct InferredType; + +// For custom class __init__ registration, we need to pass in a function +// that looks like this: [](IValue x, args...) + +// However, make_boxed_from_unboxed_functor.h automatically sets the input types +// of the function by introspecting the types of the functor (which is IValue in +// this case). However, we need the type it binds to be Foo. + +// Instead, we pass in a lambda [](ivalue_holder x, args...) from +// which getTypePtr can recover the original class pointer. + +template +struct tagged_capsule { + IValue ivalue; +}; + +template +c10::intrusive_ptr IValue::moveToIntrusivePtr() { + auto t = c10::intrusive_ptr::reclaim( + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); + clearToNone(); + return t; +} +template +c10::intrusive_ptr IValue::toIntrusivePtr() const { + if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { + return c10::intrusive_ptr(); + } + c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); + return c10::intrusive_ptr::reclaim( + static_cast(payload.u.as_intrusive_ptr)); +} + +template +intrusive_ptr static_intrusive_pointer_cast(intrusive_ptr r) { + return intrusive_ptr::reclaim(static_cast(r.release())); +} + +template +intrusive_ptr dynamic_intrusive_pointer_cast(intrusive_ptr r) { + return intrusive_ptr::reclaim(dynamic_cast(r.release())); +} + +inline c10::intrusive_ptr IValue::toFuture() && { + AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toFuture() const& { + AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toAwait() && { + AT_ASSERT(isAwait(), "Expected Await but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toAwait() const& { + AT_ASSERT(isAwait(), "Expected Await but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toRRef() && { + AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toRRef() const& { + AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toQuantizer() && { + AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toQuantizer() const& { + AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); + return toIntrusivePtr(); +} +inline 
c10::intrusive_ptr IValue::toString() && { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toString() const& { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toObject() && { + AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toObject() const& { + AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue:: + toPyObjectHolder() && { + TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toPyObjectHolder() + const& { + TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toEnumHolder() && { + TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toEnumHolder() const& { + TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::complex IValue::toComplexDouble() const { + TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind()); + auto ptr = toIntrusivePtr(); + return (*ptr).val; +} +inline at::Tensor IValue::toTensor() && { + if (C10_UNLIKELY(!isTensor())) { + reportToTensorTypeError(); + } + auto result = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. 
+ // + // payload.as_tensor.~Tensor(); + clearToNone(); + return result; +} +inline at::Tensor& IValue::toTensor() & { + if (C10_UNLIKELY(!isTensor())) { + reportToTensorTypeError(); + } + return payload.as_tensor; +} +inline const at::Tensor& IValue::toTensor() const& { + if (C10_UNLIKELY(!isTensor())) { + reportToTensorTypeError(); + } + return payload.as_tensor; +} +inline c10::Storage IValue::toStorage() && { + AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); + return c10::Storage( + moveToIntrusivePtr()); +} +inline c10::Storage IValue::toStorage() const& { + AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); + return c10::Storage(toIntrusivePtr()); +} +inline c10::Stream IValue::toStream() && { + AT_ASSERT(isStream(), "Expected Stream but got ", tagKind()); + auto ptr = toIntrusivePtr(); + return c10::Stream::unpack3((*ptr).val.stream_id, + (*ptr).val.device_index, + (*ptr).val.device_type); +} +inline c10::Stream IValue::toStream() const& { + AT_ASSERT(isStream(), "Expected Stream but got ", tagKind()); + auto ptr = toIntrusivePtr(); + return c10::Stream::unpack3((*ptr).val.stream_id, + (*ptr).val.device_index, + (*ptr).val.device_type); +} +inline c10::intrusive_ptr IValue::toBlob() && { + AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toBlob() const& { + AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); + return toIntrusivePtr(); + ; +} +inline c10::intrusive_ptr IValue::toCapsule() && { + TORCH_INTERNAL_ASSERT(isCapsule()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toCapsule() const& { + TORCH_INTERNAL_ASSERT(isCapsule()); + return toIntrusivePtr(); +} +inline at::Generator IValue::toGenerator() && { + AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); + return at::Generator(moveToIntrusivePtr()); +} +inline at::Generator IValue::toGenerator() const& { + AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); + return at::Generator(toIntrusivePtr()); +} +inline c10::SymInt IValue::toSymInt() && { + AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); + if (isSymInt()) { + return c10::SymInt(moveToIntrusivePtr()); + } else { + return c10::SymInt(payload.u.as_int); + } +} +inline c10::SymInt IValue::toSymInt() const& { + AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); + if (isSymInt()) { + return c10::SymInt(toIntrusivePtr()); + } else { + return c10::SymInt(payload.u.as_int); + } +} +inline c10::SymFloat IValue::toSymFloat() && { + AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); + if (isSymFloat()) { + return c10::SymFloat(moveToIntrusivePtr()); + } else { + return c10::SymFloat(payload.u.as_double); + } +} +inline c10::SymFloat IValue::toSymFloat() const& { + AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); + if (isSymFloat()) { + return c10::SymFloat(toIntrusivePtr()); + } else { + return c10::SymFloat(payload.u.as_double); + } +} +inline c10::SymBool IValue::toSymBool() && { + AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind()); + if (isSymBool()) { + return c10::SymBool(moveToIntrusivePtr()); + } else { + return c10::SymBool(payload.u.as_bool); + } +} + +inline c10::SymBool IValue::toSymBool() const& { + AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind()); + if (isSymBool()) { + return 
c10::SymBool(toIntrusivePtr()); + } else { + return c10::SymBool(payload.u.as_bool); + } +} + +namespace ivalue { + +void TORCH_API +checkCustomClassType(const ClassType* expected_type, const Type* actual_type); + +template +using Shared = c10::intrusive_ptr; + +// string +struct TORCH_API ConstantString final : c10::intrusive_ptr_target { + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::string str_; + + public: + ConstantString(std::string str) : str_(std::move(str)) {} + ConstantString(std::string_view str) : str_(std::string(str)) {} + static c10::intrusive_ptr create(std::string str_); + static c10::intrusive_ptr create(std::string_view str_); + static c10::intrusive_ptr create(const char* str_); + + const std::string& string() const { + return str_; + } + std::string_view string_view() const { + return str_; + } + + operator const std::string&() const { + return string(); + } + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const ConstantString& v); +}; + +struct Future; + +struct TORCH_API TupleElements { + private: + size_t inlineSize_; + // We represent TupleElements this way to save doing a heap + // allocation in the common (at least for unpickling) case where we + // have only 3 elements. We have our own union instead of + // c10::SmallVector because c10::SmallVector always + // stores the begin/end/capacity pointers, which would be a waste of + // space in our use case. + union { + std::vector elementsVector_; + // Don't want to declare a std::array because the convenient + // iteration and size members are a footgun in this case -- the + // actual size of the array may be smaller than 3! + // NOLINTNEXTLINE(*c-arrays*) + IValue elementsInline_[3]; + }; + + void destroyInline() { + for (const auto ii : c10::irange(inlineSize_)) { + elementsInline_[ii].~IValue(); + } + } + public: + + using iterator = IValue*; + using const_iterator = const IValue*; + + TupleElements() : inlineSize_(0) { + new (&elementsVector_) std::vector(); + } + + explicit TupleElements(std::vector elements) + : inlineSize_(0), elementsVector_(std::move(elements)) {} + + explicit TupleElements(c10::ArrayRef elements) + : inlineSize_(elements.size() <= 3 ? 
elements.size() : 0) { + switch (inlineSize_) { + case 3: + new (&elementsInline_[2]) IValue(elements[2]); + [[fallthrough]]; + case 2: + new (&elementsInline_[1]) IValue(elements[1]); + [[fallthrough]]; + case 1: + new (&elementsInline_[0]) IValue(elements[0]); + break; + case 0: + new (&elementsVector_) std::vector(elements.begin(), elements.end()); + break; + } + } + + explicit TupleElements(IValue&& e1) + : inlineSize_(1) { + new (&elementsInline_[0]) IValue(std::move(e1)); + } + + explicit TupleElements(IValue&& e1, IValue&& e2) + : inlineSize_(2) { + new (&elementsInline_[0]) IValue(std::move(e1)); + new (&elementsInline_[1]) IValue(std::move(e2)); + } + + explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3) + : inlineSize_(3) { + new (&elementsInline_[0]) IValue(std::move(e1)); + new (&elementsInline_[1]) IValue(std::move(e2)); + new (&elementsInline_[2]) IValue(std::move(e3)); + } + + ~TupleElements() { + if (inlineSize_) { + destroyInline(); + } else { + elementsVector_.~vector(); + } + } + + // It would be nice to make this noncopyable to prevent people from + // writing code like `auto output = + // forward(...).toTupleRef().elements()` (which does refcount bumps on + // each element, unlike the more efficient but verbose + // ``` + // auto outputIntrusivePtr = forward(...).toTuple(); + // const auto& output = outputIntrusivePtr->elements(); + // ``` + // ), but there is simply an overwhelming amount of code that does + // it the inefficient way. + // See also operator std::vector below. + TupleElements(const TupleElements& rhs) + : inlineSize_(rhs.inlineSize_) { + if (rhs.inlineSize_) { + for (const auto ii : c10::irange(inlineSize_)) { + new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); + } + } else { + new (&elementsVector_) std::vector(rhs.elementsVector_); + } + } + + TupleElements& operator=(const TupleElements& rhs) { + if (inlineSize_) { + if (rhs.inlineSize_) { + for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { + elementsInline_[ii] = rhs.elementsInline_[ii]; + } + if (rhs.inlineSize_ > inlineSize_) { + for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { + new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); + } + } else { + for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) { + elementsInline_[ii].~IValue(); + } + } + } else { + destroyInline(); + new (&elementsVector_) std::vector(rhs.elementsVector_); + } + } else { + if (rhs.inlineSize_) { + elementsVector_.~vector(); + for (const auto ii : c10::irange(rhs.inlineSize_)) { + new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); + } + } else { + elementsVector_ = rhs.elementsVector_; + } + } + inlineSize_ = rhs.inlineSize_; + return *this; + } + + TupleElements(TupleElements&& rhs) noexcept + : inlineSize_(rhs.inlineSize_) { + if (inlineSize_) { + for (const auto ii : c10::irange(inlineSize_)) { + new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); + } + } else { + new (&elementsVector_) std::vector(std::move(rhs.elementsVector_)); + } + } + + TupleElements& operator=(TupleElements&& rhs) noexcept { + if (inlineSize_) { + if (rhs.inlineSize_) { + for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { + elementsInline_[ii] = std::move(rhs.elementsInline_[ii]); + } + if (rhs.inlineSize_ > inlineSize_) { + for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { + new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); + } + } else { + for (const auto ii : 
c10::irange(rhs.inlineSize_, inlineSize_)) { + elementsInline_[ii].~IValue(); + } + } + } else { + destroyInline(); + new (&elementsVector_) std::vector(std::move(rhs.elementsVector_)); + } + } else { + if (rhs.inlineSize_) { + elementsVector_.~vector(); + for (const auto ii : c10::irange(rhs.inlineSize_)) { + new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); + } + } else { + elementsVector_ = std::move(rhs.elementsVector_); + } + } + inlineSize_ = rhs.inlineSize_; + return *this; + } + + [[nodiscard]] c10::ArrayRef asArrayRef() const { + if (inlineSize_) { + return c10::ArrayRef(elementsInline_, inlineSize_); + } else { + return elementsVector_; + } + } + + // Mimic implicit conversion from std::vector to ArrayRef. + operator c10::ArrayRef() const { + return asArrayRef(); + } + + static size_t hash(const TupleElements& v) { + return c10::hash>()(v.asArrayRef()); + } + + void setContents(std::vector&& contents) { + if (inlineSize_) { + destroyInline(); + new (&elementsVector_) std::vector(std::move(contents)); + inlineSize_ = 0; + } else { + elementsVector_ = std::move(contents); + } + } + + [[nodiscard]] bool empty() const { + return inlineSize_ ? false : elementsVector_.empty(); + } + + [[nodiscard]] size_t size() const { + return inlineSize_ ? inlineSize_ : elementsVector_.size(); + } + + [[nodiscard]] IValue& operator[](size_t idx) { + if (inlineSize_) { + return elementsInline_[idx]; + } else { + return elementsVector_[idx]; + } + } + + [[nodiscard]] const IValue& operator[](size_t idx) const { + if (inlineSize_) { + return elementsInline_[idx]; + } else { + return elementsVector_[idx]; + } + } + + [[nodiscard]] IValue& at(size_t idx) { + if (inlineSize_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); + TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); + return elementsInline_[idx]; + } else { + return elementsVector_.at(idx); + } + } + + [[nodiscard]] const IValue& at(size_t idx) const { + if (inlineSize_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); + TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); + return elementsInline_[idx]; + } else { + TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size()); + return elementsVector_.at(idx); + } + } + + [[nodiscard]] iterator begin() { + if (inlineSize_) { + return elementsInline_; + } else { + return elementsVector_.data(); + } + } + + [[nodiscard]] iterator end() { + if (inlineSize_) { + return elementsInline_ + inlineSize_; + } else { + return elementsVector_.data() + elementsVector_.size(); + } + } + + [[nodiscard]] const_iterator begin() const { + if (inlineSize_) { + return elementsInline_; + } else { + return elementsVector_.data(); + } + } + + [[nodiscard]] const_iterator end() const { + if (inlineSize_) { + return elementsInline_ + inlineSize_; + } else { + return elementsVector_.data() + elementsVector_.size(); + } + } + + [[nodiscard]] const_iterator cbegin() const { + return begin(); + } + + [[nodiscard]] const_iterator cend() const { + return end(); + } + + [[nodiscard]] std::vector vec() const& { + return asArrayRef().vec(); + } + + [[nodiscard]] IValue& back() { + return *(end() - 1); + } + + [[nodiscard]] const IValue& back() const { + return *(end() - 1); + } + + [[nodiscard]] std::vector vec() && { + std::vector result; + result.reserve(size()); + for (auto&& iv : *this) { + 
result.push_back(std::move(iv)); + } + return result; + } + + // More compatibility shims for the overwhelming amount of code that + // likes to copy tuple elements into a vector; see comment above the + // copy constructor. + operator std::vector() const & { + return vec(); + } + + operator std::vector() && { + return vec(); + } +}; + +template +struct TupleTypeFactory {}; + +template <> +struct TORCH_API TupleTypeFactory { + static TupleTypePtr create(std::vector types) { + return TupleType::create(std::move(types)); + } + static TupleTypePtr fallback(const Type& type); +}; + +template <> +struct TORCH_API TupleTypeFactory { + static DynamicTypePtr create(const std::vector& elemTypes); + static DynamicTypePtr fallback(const Type&); +}; + +struct TORCH_API Tuple : c10::intrusive_ptr_target { + private: + TupleElements elements_; + mutable c10::TypePtr type_; // lazily computed for unnamed tuples + + public: + // named tuples have additional type information, so we + // directly create them tagged + static c10::intrusive_ptr createNamed( + std::vector elements_, + c10::TypePtr type_) { + return c10::make_intrusive(std::move(elements_), std::move(type_)); + } + + static c10::intrusive_ptr createNamed( + TupleElements elements_, + std::shared_ptr type_) { + return c10::make_intrusive(std::move(elements_), std::move(type_)); + } + + static c10::intrusive_ptr createNamed( + std::initializer_list elements_, + std::shared_ptr type_) { + return createNamed(TupleElements(c10::ArrayRef(elements_)), std::move(type_)); + } + + // MSVC apparently can't disambiguate the other two overloads of + // create when passed an initializer_list without this. + static c10::intrusive_ptr create(std::initializer_list elements_) { + return create(c10::ArrayRef(elements_)); + } + + static c10::intrusive_ptr create(std::vector elements_) { + return c10::make_intrusive(std::move(elements_)); + } + + static c10::intrusive_ptr create(TupleElements elements_) { + return c10::make_intrusive(std::move(elements_)); + } + + static c10::intrusive_ptr create(c10::ArrayRef elements_) { + return create(TupleElements(elements_)); + } + + static c10::intrusive_ptr create(IValue e1) { + return c10::make_intrusive(std::move(e1)); + } + + static c10::intrusive_ptr create(IValue e1, IValue e2) { + return c10::make_intrusive(std::move(e1), std::move(e2)); + } + + static c10::intrusive_ptr create(IValue e1, IValue e2, IValue e3) { + return c10::make_intrusive(std::move(e1), std::move(e2), std::move(e3)); + } + + private: + // Workaround inability to use `>` operator in template argument list. + template + static constexpr bool hasMoreThanThreeArgs() { + return sizeof...(Args) > 3; + } + + public: + template + static c10::intrusive_ptr create(Args&&... elements_) { + switch (sizeof...(Args)) { + case 1: + case 2: + case 3: + return create(IValue(std::forward(elements_))...); + default: + return create( + std::vector{IValue(std::forward(elements_))...}); + } + } + + // Again, it would be nice to make this noncopyable, but there's a + // lot of extant code that copies Tuples. 
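// Example (editor's illustrative sketch; not part of the original header):
// creating a Tuple and reading it back. With three or fewer elements the
// TupleElements payload stays inline, so no vector allocation is needed.
//
//   auto tup = ivalue::Tuple::create(IValue(1), IValue(2.5), IValue(true));
//   const TupleElements& elems = tup->elements();    // reference, no copy
//   int64_t i = elems[0].toInt();                    // 1
//   double d = elems[1].toDouble();                  // 2.5
//   IValue boxed(tup);                               // box the tuple itself
//   size_t n = boxed.toTupleRef().elements().size(); // 3, still no copy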
+ // Tuple(const Tuple& rhs) = delete; + + const TupleElements& elements() const& { + return elements_; + } + + TupleElements elements() && { + return std::move(elements_); + } + + void setElements(std::vector&& elements) { + elements_.setContents(std::move(elements)); + } + + void setElements(TupleElements&& elements) { + elements_ = std::move(elements); + } + + void unsafeSetElement(size_t idx, const IValue& element) { + elements_[idx] = element; + } + + void unsafeSetElement(size_t idx, IValue&& element) { + elements_[idx] = std::move(element); + } + + size_t size() const { + return elements_.size(); + } + + template + std::shared_ptr type() const { + if (!type_) { + type_ = TupleTypeFactory::create(fmap(elements(), [&](const IValue& v) { + return v.type(); + })); + } + if (auto t = type_->cast()) { + return t; + } + return TupleTypeFactory::fallback(*type_); + } + + static size_t hash(const Tuple& t) { + return c10::get_hash(t.elements()); + } + + TORCH_API friend bool operator==( + const ivalue::Tuple& lhs, + const ivalue::Tuple& rhs); + + private: + // NOTE: If we try to avoid the overloads without + // `std::shared_ptr type` by defaulting it to nullptr, we + // end up having to call (part of) the shared_ptr destructor for + // `type` even though we should know statically it won't do + // anything. + explicit Tuple(std::vector elements) + : elements_(std::move(elements)){} + + explicit Tuple(std::vector elements, c10::TypePtr type) + : elements_(std::move(elements)), type_(std::move(type)) {} + + explicit Tuple(TupleElements&& elements) + : elements_(std::move(elements)) {} + + explicit Tuple(TupleElements&& elements, std::shared_ptr type) + : elements_(std::move(elements)), type_(std::move(type)) {} + + explicit Tuple(IValue&& e1) + : elements_(std::move(e1)) {} + + explicit Tuple(IValue&& e1, std::shared_ptr type) + : elements_(std::move(e1)), type_(std::move(type)) {} + + explicit Tuple(IValue&& e1, IValue&& e2) + : elements_(std::move(e1), std::move(e2)) {} + + explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr type) + : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {} + + explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3) + : elements_(std::move(e1), std::move(e2), std::move(e3)) {} + + explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr type) + : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {} + + friend class c10::intrusive_ptr; +}; + +struct Object; +struct PyObjectHolder; +struct EnumHolder; +} // namespace ivalue + +// Future +struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { + private: + // Keep this private in order to force users to go through make_intrusive and + // thus prevent creating a Future that's not held by an intrusive_ptr. 
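// Example (editor's illustrative sketch; not part of the original header,
// and assuming the JIT type singletons such as IntType::get() are
// available): because the constructor below is private, a Future is always
// created through c10::make_intrusive and handled via intrusive_ptr.
//
//   auto fut = c10::make_intrusive<ivalue::Future>(IntType::get());
//   fut->addCallback([](ivalue::Future& f) {
//     // runs once the future completes (or immediately if already done)
//   });
//   fut->markCompleted(IValue(42));
//   fut->wait();
//   int64_t result = fut->value().toInt();           // 42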
+ explicit Future(TypePtr type, std::vector devices={}) + : type_(std::move(type)), + impl_(getTypeOfDevices(devices)), + devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {} + + friend c10::intrusive_ptr; + + struct FutureCallback { + std::function callback; + bool uses_future; // whether the Future& passed in is actually used + + template + FutureCallback(T callback, bool uses_future) + : callback(std::move(callback)), uses_future(uses_future) {} + }; + + public: + Future(const Future&) = delete; + Future(Future&&) = delete; + Future& operator=(const Future&) = delete; + Future& operator=(Future&&) = delete; + + // Destructor + // Explicitly destroy events under device guard, otherwise it can lead to + // extra context being created on device 0. Reason: python garbage collector + // calls this destructor, but python GC does not have a device context, so a + // "default" one (usually on device 0) could be created when we go down the + // line of event destroy. + ~Future() override { + while (!events_.empty()) { + c10::OptionalDeviceGuard deviceGuard(events_.back().device()); + events_.pop_back(); + } + } + + struct TORCH_API FutureError final : public std::exception { + explicit FutureError(std::string&& error_msg_) + : error_msg(std::move(error_msg_)) {} + + FutureError() = default; + + const char* what() const noexcept override { + return error_msg.c_str(); + } + + std::string error_msg; + }; + + /** + * Wait on the future until it completes. + */ + void wait() { + std::unique_lock lock(mutex_); + finished_cv_.wait(lock, [&]() -> bool { return completed_; }); + synchronizeWithCurrentStreams(); + } + + /** + * Wait on the future until it completes and throw an + * exception if an error exists. + */ + void waitAndThrow() { + wait(); + + if (eptr_) { + std::rethrow_exception(eptr_); + } + } + + /** + * Explicitly mark the future as completed with the output value. Optionally, + * the storages for all tensors in IValue can be passed as well. The DataPtrs + * of these storages are used to synchronize CUDA streams. If storages isn't + * given we will attempt to extract it from the value, if we need to (this + * happens if a non-empty set of devices was given to the constructor). Thus + * one only needs to provide storages when 1) they cannot be extracted through + * IValue::getSubValues() or through pickling in case of Python object; or + * when 2) customized storage extraction is more efficient. + */ + using WeakStorage = c10::weak_intrusive_ptr; + void markCompleted( + IValue value, + std::optional> storages = std::nullopt) { + // Start by performing all steps that can throw, before setting any field. + // Do this before even acquiring the mutex, because extractStorages might + // acquire the GIL, which could lead to a lock inversion with our mutex. + // See https://github.com/pytorch/pytorch/issues/58239. + std::vector actualStorages; + std::vector usedDevices; + try { + // FIXME We should always extract DataPtrs, in order to catch the case of + // users using CUDA values but forgetting to set devices, which currently + // leads to a silent synchronization/correctness issue. However, as this + // might worsen perf in CPU-only cases, we should only do so after careful + // benchmarks. + if (impl_.type() != c10::kCPU) { + actualStorages = + storages.has_value() ? 
std::move(*storages) : extractStorages(value); + usedDevices = getDevicesOfStorages(impl_, actualStorages); + ensureIsSubsetOfDevices(usedDevices, devices_); + } + } catch (const std::exception&) { + setError(std::current_exception()); + return; + } + + std::unique_lock lock(mutex_); + TORCH_CHECK( + !completed(), + "Attempting to mark a completed Future as complete again. Note that " + "a Future can only be marked completed once."); + + // Only set value_ and completed_ flag once all checks and preparation steps + // have returned successfully to allow for proper error propagation. + value_ = std::move(value); + completed_ = true; + + currentDevice_ = impl_.getDevice(); + storages_ = std::move(actualStorages); + for (const c10::Device& device : usedDevices) { + c10::Event event(impl_.type()); + event.record(impl_.getStream(device)); + events_.push_back(std::move(event)); + } + + std::vector cbs; + cbs.swap(callbacks_); + lock.unlock(); + + finished_cv_.notify_all(); + for (const auto& callback : cbs) { + invokeCallback(callback.callback, callback.uses_future); + } + } + + void markCompleted() { + markCompleted(IValue{}); + } + + void setError(std::exception_ptr eptr) { + std::unique_lock lock(mutex_); + setErrorInternal(std::move(eptr), lock); + } + + void setErrorIfNeeded(std::exception_ptr eptr) { + std::unique_lock lock(mutex_); + if (completed_) { + // This should be rare and shouldn't cause log spew. Its important to + // log errors and thats why we have this log here. + std::string msg = c10::str( + "Skipping setting following error on the Future since " + "it is already marked completed (this is not necessarily " + "an error):\n", + tryRetrieveErrorMessageInternal(std::move(eptr))); + if (eptr_) { + msg += c10::str( + ", \nOriginal exception:\n", + tryRetrieveErrorMessageInternal(eptr_)); + } + LOG(INFO) << msg; + return; + } else { + setErrorInternal(std::move(eptr), lock); + } + } + + // Get the result of the current future. + IValue value() { + std::unique_lock lock(mutex_); + AT_ASSERT(completed()); + if (eptr_) { + std::rethrow_exception(eptr_); + } + return value_; + } + + // This accessor should only be used if we know that the future is + // completed() with no error. + const IValue& constValue() const { + std::unique_lock lock(mutex_); + AT_ASSERT(completed()); + TORCH_INTERNAL_ASSERT( + !eptr_, + "value() accessor should only be used when future is not completed with ", + "an error, but future had the following error: ", + tryRetrieveErrorMessageInternal(eptr_) + ); + return value_; + } + + // This accessor should only be used if we know that the future is + // completed() with no error. + const std::vector& storages() const { + std::unique_lock lock(mutex_); + AT_ASSERT(completed()); + AT_ASSERT(!eptr_); + return storages_; + } + + /** + * Add a callback to the future. + * The callbacks will be executed once the future completes. + * If the future has already completed, + * this function will execute the callback immediately. + */ + template + void addCallback(T callback, bool uses_future = true) { + static_assert( + std::is_invocable_r_v, + "The callback must have signature void(Future&)"); + + std::unique_lock lock(mutex_); + if (completed()) { + lock.unlock(); + invokeCallback(callback, uses_future); + return; + } + callbacks_.emplace_back(std::move(callback), uses_future); + } + + /** + * Add a callback to the future, and return another Future to hold the return + * value of the callback. 
This is necessary when the callback provider needs + * to know for sure when the callback has finished. + */ + template + c10::intrusive_ptr then(T callback, TypePtr type) { + using IValueWithStorages = std::tuple>; + static_assert( + std::disjunction_v< + std::is_invocable_r, + std::is_invocable_r>, + "The callback must have signature IValue(Future&) or " + "std::tuple>(Future&)"); + + auto childFut = createInstance(::std::move(type)); + addCallback([childFut, + cb = std::move(callback)](Future& parentFut) { + try { + if constexpr (::std::is_convertible_v, IValueWithStorages>) { + auto [ivalue, storages] = cb(parentFut); + childFut->markCompleted(::std::move(ivalue), ::std::move(storages)); + } else { + childFut->markCompleted(cb(parentFut)); + } + } catch (std::exception&) { + childFut->setError(std::current_exception()); + } + }); + return childFut; + } + + template + c10::intrusive_ptr thenAsync(T callback, TypePtr type) { + static_assert( + std::is_invocable_r_v, T, Future&>, + "The callback must have signature c10::intrusive_ptr(Future&)"); + + auto childFut = createInstance(std::move(type)); + addCallback( + [childFut, cb = std::move(callback)](Future& parentFut) mutable { + c10::intrusive_ptr intermediateFut; + try { + intermediateFut = cb(parentFut); + } catch (std::exception&) { + childFut->setError(std::current_exception()); + return; + } + intermediateFut->addCallback( + [childFut = std::move(childFut)](Future& intermediateFut) { + if (intermediateFut.hasError()) { + childFut->setError(intermediateFut.exception_ptr()); + } else { + childFut->markCompleted( + intermediateFut.value(), intermediateFut.storages()); + } + }); + }); + return childFut; + } + + // Tries to retrieve the error message from std::exception_ptr. + std::string tryRetrieveErrorMessage() const { + TORCH_CHECK(hasError(), "No error present on the future."); + std::unique_lock lock(mutex_); + return tryRetrieveErrorMessageInternal(eptr_); + } + + // Check if the current future has completed + bool completed() const { + return completed_; + } + + bool hasValue() const { + std::unique_lock lock(mutex_); + return completed_ && !eptr_; + } + + bool hasError() const { + std::unique_lock lock(mutex_); + return eptr_ ? true : false; + } + + std::exception_ptr exception_ptr() const { + std::unique_lock lock(mutex_); + return eptr_; + } + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const Future& v); + + const TypePtr& elementType() const { + return type_; + } + + const std::vector& devices() const { + return devices_; + } + + // This method should be used when one intends to manually create a child + // future, for example when implementing a customized version of then(). + c10::intrusive_ptr createInstance(at::TypePtr type) { + return c10::make_intrusive(std::move(type), devices_); + } + + private: + + // This method should always be used when invoking a callback (regardless of + // how/when that happens) as it will ensure that the proper "environment" is + // set up before running the callback, as in, it will set up the CUDA streams, + // synchronize them with the value, and so on (if needed). + template + void invokeCallback(T& callback, bool uses_future) { + static_assert( + std::is_invocable_r_v, + "The callback must have signature void(Future&)"); + + // The synchronization performed below shouldn't be needed when the future + // is not used by the callback. 
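  // invokeCallback is what ultimately runs callbacks registered through
  // addCallback()/then() above. An illustrative sketch of the user-facing
  // side (not part of the original header); `parentFut` and `childType` are
  // assumed to exist:
  //
  //   auto child = parentFut->then(
  //       [](c10::ivalue::Future& f) { return c10::IValue(f.value().toInt() + 1); },
  //       childType);
  //   child->wait();
  //   TORCH_INTERNAL_ASSERT(child->value().toInt() == parentFut->value().toInt() + 1);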
+ if (uses_future) { + c10::OptionalDeviceGuard deviceGuard(currentDevice_); + + std::vector streams; + streams.reserve(devices_.size()); + for (const c10::Device& device : devices_) { + streams.push_back(impl_.getStreamFromGlobalPool(device)); + } + c10::MultiStreamGuard streamGuard(streams); + synchronizeWithCurrentStreams(); + callback(*this); + } else { + callback(*this); + } + } + + // This method should be called before this future's value is used, as it + // ensures that the CUDA streams that are "current" at the callsite properly + // synchronize with the value. + void synchronizeWithCurrentStreams() { + for (c10::Event& event : events_) { + event.block(impl_.getStream(event.device())); + } + + for (const WeakStorage& weak_storage : storages_) { + c10::intrusive_ptr storage = weak_storage.lock(); + if (!storage) { + continue; + } + if (!storage->device().is_cpu()) { + impl_.recordDataPtrOnStream( + storage->data_ptr(), impl_.getStream(storage->device())); + } + } + } + + void setErrorInternal( + std::exception_ptr eptr, + std::unique_lock& lock) { + TORCH_CHECK( + !eptr_, + "Error already set on this Future: ", + tryRetrieveErrorMessageInternal(eptr_), + ", trying to set error: ", + tryRetrieveErrorMessageInternal(eptr)); + TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed"); + completed_ = true; + eptr_ = std::move(eptr); + + std::vector cbs; + cbs.swap(callbacks_); + lock.unlock(); + + finished_cv_.notify_all(); + for (const auto& callback : cbs) { + invokeCallback(callback.callback, callback.uses_future); + } + } + + // Tries to retrieve the error message from std::exception_ptr. + std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const { + try { + std::rethrow_exception(std::move(eptr)); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown Exception Type"; + } + } + + // Defined in ivalue.cpp. 
+ static std::vector extractStorages( + const at::IValue& value); + + static std::vector getDevicesOfStorages( + const c10::impl::VirtualGuardImpl& impl, + const std::vector& storages) { + c10::DeviceIndex deviceCount = impl.deviceCount(); + std::vector isDeviceUsed(deviceCount, false); + for (const WeakStorage& weak_storage : storages) { + c10::intrusive_ptr storage = weak_storage.lock(); + if (!storage) { + continue; + } + c10::Device device = storage->device(); + if (!device.is_cpu()) { + TORCH_CHECK_VALUE( + device.type() == impl.type(), + "Expected all data ptrs to be on a device of type ", + impl.type(), + ", got one on device ", + device); + isDeviceUsed[device.index()] = true; + } + } + std::vector devices; + for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) { + if (isDeviceUsed[idx]) { + devices.emplace_back(impl.type(), idx); + } + } + return devices; + } + + static std::string formatSetOfDevices( + const std::vector& devices) { + if (devices.empty()) { + return "(none)"; + } + std::ostringstream oss; + oss << devices[0]; + for (const auto idx : c10::irange(1, devices.size())) { + if (idx == devices.size() - 1) { + oss << " and "; + } else { + oss << ", "; + } + oss << devices[idx]; + } + return oss.str(); + } + + static c10::DeviceType getTypeOfDevices( + const std::vector& devices) { + if (devices.empty()) { + return c10::kCPU; + } + c10::DeviceType deviceType = devices[0].type(); + for (const auto idx : c10::irange(1, devices.size())) { + TORCH_CHECK_VALUE( + devices[idx].type() == deviceType, + "Expected all devices to be of the same type, but got a mismatch between ", + devices[0], + " and ", + devices[idx]); + } + return deviceType; + } + + // We need devices to be sorted in order to use ensureIsSubsetOfDevices. + static std::vector sortAndDeduplicateDevices( + const c10::impl::VirtualGuardImpl& /*impl*/, + std::vector devices) { + std::sort( + devices.begin(), devices.end(), + [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); + // Deduplicate by compacting. + size_t targetIdx = 0; + for (const auto sourceIdx : c10::irange(devices.size())) { + TORCH_CHECK_VALUE( + devices[sourceIdx].has_index(), + "Expected devices to have indices, got ", devices[sourceIdx]); + if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) { + // It's a duplicate, skip it. + continue; + } + if (sourceIdx != targetIdx) { + devices[targetIdx] = devices[sourceIdx]; + } + targetIdx++; + } + // If there were duplicates there's now a gap at the end: trim it. Resizing + // requires the item type to be default-constructible (which c10::Device is + // not) because in principle it could be required to create new items. Since + // we know we'll shrink the vector, we provide a custom dummy value instead. + devices.resize(targetIdx, c10::Device(c10::kCPU)); + return devices; + } + + static void ensureIsSubsetOfDevices( + const std::vector& subset, + const std::vector& superset) { + // We assume the devices in both vectors have the same consistent type, and + // their indices are unique and sorted. 
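  // For example (sketch, not part of the original header): with a superset of
  // {cuda:0, cuda:1}, a subset of {cuda:1} passes, while {cuda:2} throws and
  // the message names cuda:2 via formatSetOfDevices() above:
  //
  //   ensureIsSubsetOfDevices({c10::Device(c10::kCUDA, 1)},
  //                           {c10::Device(c10::kCUDA, 0), c10::Device(c10::kCUDA, 1)});  // ok
  //   ensureIsSubsetOfDevices({c10::Device(c10::kCUDA, 2)},
  //                           {c10::Device(c10::kCUDA, 0), c10::Device(c10::kCUDA, 1)});  // throws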
+ std::vector excessDevices; + std::set_difference( + subset.begin(), + subset.end(), + superset.begin(), + superset.end(), + std::back_inserter(excessDevices), + [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); + TORCH_CHECK_VALUE( + excessDevices.empty(), + "The result contained tensors residing on device(s) ", + formatSetOfDevices(excessDevices), + " which are not among the expected device(s) ", + formatSetOfDevices(superset)); + } + + mutable std::mutex mutex_; + std::atomic_bool completed_ = {false}; // is this future complete + std::condition_variable finished_cv_; + + IValue value_; // when finished the value + TypePtr type_; + std::vector callbacks_; + std::exception_ptr eptr_; + + // An upcast pointer to a virtual class which allows us to manipulate events, + // streams, ... in a generic way, without an explicit dependency on CUDA. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const c10::impl::VirtualGuardImpl impl_; + + // The device that was current when markCompleted was called, which we'll + // restore when invoking callbacks. It's optional because we'll only store it + // if the future completes successfully. + std::optional currentDevice_; + + // The events that correspond to the completion of the async I/O kernels. They + // are recorded on the appropriate streams when the future is marked completed + // and can then be queried/waited/blocked on. There is one event for each + // distinct device on which the value's tensors reside. + std::vector events_; + + // A cached version of the storages extracted from the value when the future + // is first marked completed. + std::vector storages_; + + // The bounding set of devices that this future, and any of its children, is + // allowed to use. This is a superset of the set of devices used by the events + // above. We need this to know what streams (for which devices) to set as + // current when invoking a callback, thus allowing the callback to use devices + // that the parent future didn't use. This field is set to the value provided + // in the constructor and will be "inherited" by all child futures. 
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::vector devices_; +}; + +struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target { + private: + explicit Await(TypePtr elType, std::function fn) + : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {} + + explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { } + + friend c10::intrusive_ptr; + + public: + Await(const Await&) = delete; + Await(Await&&) = delete; + Await& operator=(const Await&) = delete; + Await& operator=(Await&&) = delete; + ~Await() override = default; + + IValue wait() { + if (!completed_) { + TORCH_CHECK(fn_, "Incompleted Await: fn can't be None"); + value_ = fn_(); + completed_ = true; + args_ = {}; + } + return value_; + } + + IValue value() { + TORCH_CHECK(completed_, "Await must be completed"); + return value_; + } + + void setFn(std::function fn) { + fn_ = std::move(fn); + } + + bool completed() { + return completed_; + } + + void markCompleted(IValue value) { + value_ = std::move(value); + completed_ = true; + } + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const Await& v); + + const TypePtr& elementType() const { + return elType_; + } + + const TypePtr& type() const { + return type_; + } + + void setArgs(std::vector args) { + args_ = std::move(args); + } + + std::vector& args() { + return args_; + } + + private: + TypePtr elType_; + TypePtr type_; + std::vector args_; + std::function fn_; + IValue value_; + bool completed_{}; +}; + +// Input is a list of Futures with the same target type. +// Output is a Future to the List of completed Futures. +TORCH_API intrusive_ptr collectAll( + const c10::List>& srcs); +// Input is a List of Futures with the same target type. +// Output is a Future that will be updated with a seen value. +TORCH_API intrusive_ptr collectAny( + const c10::List>& srcs); + +// User-defined object. +struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { + public: + // In general, class types hold a shared_ptr to its owning CompilationUnit, + // so that its type and methods do not get deallocated while the class exists. + // However, the CompilationUnit holds ownership of the type's graphs, so + // inserting a constant object into a Graph would create a reference cycle if + // that constant object held a shared_ptr to its CU. For these objects we + // instatiate them with non-owning references to its CU + Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) { + slots_.resize(numSlots); + } + + Object(StrongTypePtr type, size_t numSlots) + : type_(WeakOrStrongTypePtr(std::move(type))) { + slots_.resize(numSlots); + } + + static c10::intrusive_ptr create( + WeakOrStrongTypePtr type, + size_t numSlots) { + return c10::make_intrusive(std::move(type), numSlots); + } + + static c10::intrusive_ptr create( + StrongTypePtr type, + size_t numSlots) { + return c10::make_intrusive(std::move(type), numSlots); + } + + static c10::intrusive_ptr create(ClassTypePtr classType, size_t numSlots); + + /** + * Slot API. + * + * Attributes are stored as a simple vector so that lookups are fast at + * runtime. A "slot" is just an index into that vector, which can be computed + * statically if you have access to the class type. Use this API if you are + * writing compiler stuff. 
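   *
   * An illustrative sketch (not part of the original header); `objType` is an
   * assumed StrongTypePtr for a class with at least two attribute slots:
   *
   *   auto obj = c10::ivalue::Object::create(objType, /*numSlots=*/2);
   *   obj->setSlot(0, c10::IValue(1));
   *   c10::IValue v = obj->getSlot(0);  // slot index computed from the class type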
+ */ + void setSlot(size_t slot, IValue v) { + if (slot >= slots_.size()) { + // for module types, it is possible that the members of the class have + // expanded after the object was created. In this case, we expand + // the slots to the right size + resizeObject(slot); + } + slots_[slot] = std::move(v); + } + + const IValue& getSlot(size_t slot) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size()); + // NOTE: This lookup is fairly hot, so we use unchecked access to the + // vector. Errors should still be detectable with ASan. + return slots_[slot]; + } + + void unsafeRemoveSlot(size_t slot) { + TORCH_CHECK(slot < slots_.size()); + slots_.erase(slots_.begin() + static_cast(slot)); + } + + /** + * Attribute API. + * + * Wrappers around the slot stuff so that users can access attributes + * directly. Use this API if you are a user. + * + * Note: Unlike in Python, TorchScript must make a distinction between + * attributes (which are IValues) and methods (which are Methods). If you + * want a method, use `obj.type()->getMethod()` + */ + IValue getAttr(const std::string& name) const; + void setAttr(const std::string& name, IValue v); + // Remove attribute by name, caller is responsible for + // the safety of this operation + // We didn't remove the attribute in the type because the type + // might be shared by multiple objects. + // Therefore after removing attribute, the object is in an inconsistent + // state where it has more attribute types in its Type than + // the attribute slots it has, user needs to make sure the object + // has consistent by removing the attribute in type as well + void unsafeRemoveAttr(const std::string& name); + + std::string name() const; + + const std::vector& slots() const { + return slots_; + } + std::shared_ptr type() const; + + std::shared_ptr compilation_unit() { + if (type_.holds_strong_ref()) { + return type_.cu_.getStrongRefOrThrow(); + } else { + auto weak_ptr = type_.cu_.getWeakRefOrThrow(); + return std::shared_ptr(weak_ptr); + } + } + + c10::intrusive_ptr copy_to_weak_compilation_ref() const; + + void unsafe_make_weak_compilation_ref() { + type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr()); + } + + c10::intrusive_ptr copy() const; + + c10::intrusive_ptr deepcopy( + std::optional device = std::nullopt) const; + + c10::intrusive_ptr deepcopy( + IValue::HashIdentityIValueMap& memo, + std::optional device = std::nullopt) const; + + bool is_weak_compilation_ref() const { + return !type_.holds_strong_ref(); + } + + bool is_empty_strong_compilation_ref() const { + return type_.holds_empty_strong_ref(); + } + + private: + void resizeObject(size_t slot); + WeakOrStrongTypePtr type_; + std::vector slots_; +}; + +// virtual ivalue PyObjectHolder that hold a py::object, we make this virtual +// because the py::object and refcounting logic should happen in libtorch_python +// see concrete implementation in python_ivalue.h +struct ivalue::PyObjectHolder : c10::intrusive_ptr_target { + public: + virtual PyObject* getPyObject() = 0; + virtual c10::InferredType tryToInferType() = 0; + virtual IValue toIValue(const TypePtr& type, std::optional N = std::nullopt) = 0; + virtual std::string toStr() = 0; + virtual std::vector extractTensors() = 0; + + ~PyObjectHolder() override = default; +}; + +struct ivalue::EnumHolder : c10::intrusive_ptr_target { + public: + EnumHolder(std::shared_ptr type, std::string name, IValue value) + : type_(std::move(type)), + name_(std::move(name)), + value_(std::move(value)) {} + + bool is(const ivalue::EnumHolder& rhs) { + return 
*this == rhs; + } + + friend bool operator==( + const ivalue::EnumHolder& lhs, + const ivalue::EnumHolder& rhs); + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const ivalue::EnumHolder& v); + + TORCH_API const std::string& qualifiedClassName() const; + + const std::string& unqualifiedClassName() const; + + const std::string& name() const { + return name_; + } + + const IValue& value() const { + return value_; + } + + std::shared_ptr type() const { + return type_; + } + + private: + std::shared_ptr type_; + std::string name_; + IValue value_; +}; + +#undef TORCH_FORALL_TAGS + +namespace detail { + +struct _guarded_unsigned_long_unique_dummy final { + _guarded_unsigned_long_unique_dummy(int64_t){} +}; +using _guarded_unsigned_long = std::conditional_t< + std::is_same_v || + std::is_same_v, + _guarded_unsigned_long_unique_dummy, + unsigned long>; + +} // namespace detail + +inline ivalue::Object& IValue::toObjectRef() const { + AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference"); + return *static_cast(payload.u.as_intrusive_ptr); +} + +// note: when adding a DEFINE_TO case here you should also add a +// toX method to IValue. These named methods are much more discoverable +// than the to templated function. + +#define DEFINE_TO(T, method_name) \ + template <> \ + inline T IValue::to()&& { \ + return static_cast(std::move(*this).method_name()); \ + } \ + template <> \ + inline c10::detail::ivalue_to_const_ref_overload_return::type IValue::to() const& { \ + typedef c10::detail::ivalue_to_const_ref_overload_return::type return_type; \ + return static_cast(this->method_name()); \ + } + +DEFINE_TO(at::Tensor, toTensor) +DEFINE_TO(at::Storage, toStorage) +DEFINE_TO(c10::Stream, toStream) +DEFINE_TO(float, toDouble) +DEFINE_TO(double, toDouble) +DEFINE_TO(c10::complex, toComplexDouble) +DEFINE_TO(unsigned char, toInt) +DEFINE_TO(signed char, toInt) +DEFINE_TO(unsigned short, toInt) +DEFINE_TO(short, toInt) +DEFINE_TO(int, toInt) +DEFINE_TO(uint32_t, toInt) +DEFINE_TO(uint64_t, toInt) +DEFINE_TO(detail::_guarded_unsigned_long, toInt) +DEFINE_TO(int64_t, toInt) +DEFINE_TO(bool, toBool) +DEFINE_TO(c10::intrusive_ptr, toBlob) +DEFINE_TO(c10::intrusive_ptr, toString) +DEFINE_TO(c10::intrusive_ptr, toObject) +DEFINE_TO(at::Scalar, toScalar) +DEFINE_TO(c10::List, toIntList) +DEFINE_TO(c10::List, toSymIntList) +DEFINE_TO(c10::List, toDoubleList) +DEFINE_TO(c10::List>, toComplexDoubleList) +DEFINE_TO(c10::List, toBoolList) +DEFINE_TO(c10::List, toTensorList) +DEFINE_TO(c10::impl::GenericList, toList) +DEFINE_TO(c10::impl::GenericDict, toGenericDict) +DEFINE_TO(c10::intrusive_ptr, toTuple) +DEFINE_TO(std::string, toStringRef) +DEFINE_TO(std::string_view, toStringView) +DEFINE_TO(c10::intrusive_ptr, toFuture) +DEFINE_TO(c10::intrusive_ptr, toAwait) +DEFINE_TO(c10::intrusive_ptr, toRRef) +DEFINE_TO(c10::intrusive_ptr, toQuantizer) +DEFINE_TO(IValue, toIValue) +DEFINE_TO(c10::Device, toDevice) +DEFINE_TO(at::ScalarType, toScalarType) +DEFINE_TO(at::Layout, toLayout) +DEFINE_TO(at::MemoryFormat, toMemoryFormat) +DEFINE_TO(at::QScheme, toQScheme) +DEFINE_TO(at::Dimname, toDimname) +DEFINE_TO(at::Generator, toGenerator) +DEFINE_TO(c10::SymInt, toSymInt) +DEFINE_TO(c10::SymFloat, toSymFloat) +DEFINE_TO(c10::SymBool, toSymBool) + +template +struct _fake_type {}; + +// generic_to converts an IValue from a generic list or generic dict +// to a 
concrete list/dict type likelike List, Dict<...> or std::optional. +// Note that in the case of lists, this only works for IValue-based lists, +// i.e. not for int64_t, double, ... +// generic_to is an implementation detail of IValue::to and not +// supposed to be called directly. +// The _fake_type parameter allows us to overload +// based on the return type. +template +// TODO this is deprecated but we don't throw a warning because a lot of ops in +// native_functions.yaml still return std::vector. +// C10_DEPRECATED_MESSAGE("IValues based on std::vector are potentially slow +// and deprecated. Please use torch::List instead.") +std::vector generic_to(IValue ivalue, _fake_type>) { + // We need to do a deep copy of the vector because there might be other + // references to this same IValue that also use the list. We can't just + // move the elements out. + auto list = std::move(ivalue).template to>(); + std::vector result; + result.reserve(list.size()); + for (Elem v : list) { + result.push_back(std::move(v)); + } + return result; +} + +template +c10::intrusive_ptr IValue::toCustomClass() && { + static_assert( + std::is_base_of_v == true, + "toCustomClass requires that template parameter T must inherit " + "from torch::CustomClassHolder"); + auto obj = toObject(); + TORCH_CHECK( + obj->slots().size() == 1, + "Tried to cast IValue to custom class but it did " + "not contain a custom class!"); + const auto* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); + auto userObj = + c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); + return userObj; +} + +template +c10::intrusive_ptr IValue::toCustomClass() const& { + static_assert( + std::is_base_of_v == true, + "toCustomClass requires that template parameter T must inherit " + "from torch::CustomClassHolder"); + auto obj = toObject(); + TORCH_CHECK( + obj->slots().size() == 1, + "Tried to cast IValue to custom class but it did " + "not contain a custom class!"); + const auto* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); + auto userObj = + c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); + return userObj; +} + +template +T generic_to(IValue ivalue, _fake_type) { + using ElemType = typename std::remove_pointer::type::element_type; + return std::move(ivalue).template toCustomClass(); +} + +template +tagged_capsule generic_to(IValue ivalue, _fake_type>) { + return tagged_capsule{std::move(ivalue)}; +} + +template +c10::List generic_to(IValue ivalue, _fake_type>) { + return impl::toTypedList(std::move(ivalue).toList()); +} + +template +static T createVectorLikeFromList(const c10::detail::ListImpl* impl) { + T result; + result.reserve(impl->list.size()); + for (const auto & i : impl->list) { + result.push_back(i.to()); + } + return result; +} + +template +static std::vector createVectorFromList(const c10::detail::ListImpl* impl) { + return createVectorLikeFromList>(impl); +} + +template +std::vector createVectorFromList(const c10::List& impl) { + std::vector result; + result.reserve(impl.size()); + for (size_t i = 0, N = impl.size(); i < N; ++i) { + result.push_back(impl[i]); + } + return result; +} + +template +OptionalArray generic_to(IValue ivalue, _fake_type>) { + if (ivalue.isNone()) { + return {}; + } + return createVectorFromList( + std::move(ivalue).template to>() + ); +} + +namespace detail { +template +std::array generic_to_array( + IValue ivalue, + _fake_type>, + std::index_sequence) { 
+ // We need to do a deep copy of the array because there might be other + // references to this same IValue that also use the list. We can't just + // move the elements out. + auto list = std::move(ivalue).template to>(); + TORCH_CHECK( + list.size() == sizeof...(I), + "Tried to convert a List with ", + list.size(), + " elements to a fixed-size array of size ", + sizeof...(I)); + return {list[I]...}; +} +} // namespace detail + +template +std::array generic_to( + IValue ivalue, + _fake_type> ft) { + return detail::generic_to_array(ivalue, ft, std::make_index_sequence()); +} + +template +c10::Dict generic_to( + IValue ivalue, + _fake_type>) { + return impl::toTypedDict(std::move(ivalue).toGenericDict()); +} + +template +C10_DEPRECATED_MESSAGE( + "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.") +std::unordered_map generic_to( + IValue ivalue, + _fake_type>) { + std::unordered_map specialized_dict; + + for (const auto& item : std::move(ivalue).toGenericDict()) { + specialized_dict[item.key().template to()] = item.value().template to(); + } + + return specialized_dict; +} + +template +std::optional generic_to(IValue ivalue, _fake_type>) { + if (ivalue.isNone()) { + return std::nullopt; + } + return std::move(ivalue).template to(); +} + +namespace detail { +template +Tuple generic_to_tuple_impl( + const ivalue::TupleElements& t, + std::index_sequence) { + return std::make_tuple( + t[INDEX].to::type>()...); +} +} // namespace detail + +template < + typename... Args, + typename Indices = std::make_index_sequence, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t> = nullptr> +std::tuple generic_to(const IValue& ivalue, _fake_type>) { + const auto& vals = ivalue.toTupleRef().elements(); + TORCH_CHECK(vals.size() == sizeof...(Args)); + return detail::generic_to_tuple_impl>(vals, Indices{}); +} + +template +inline T IValue::to() && { + return generic_to(std::move(*this), _fake_type{}); +} + +template <> +inline std::optional IValue::to() && { + // In the default implementation, the IValue is destroyed with std::move. + // But if the unboxed type is std::optional we cannot destroy + // the IValue. 
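  // More generally, an illustrative sketch (not part of the original header)
  // of the to<>() conversions defined in this section:
  //
  //   c10::IValue iv(std::vector<int64_t>{1, 2, 3});
  //   auto vec = iv.to<std::vector<int64_t>>();          // deep copy via generic_to
  //   auto opt = iv.toOptional<std::vector<int64_t>>();  // std::nullopt only for None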
+ return generic_to(*this, _fake_type>{}); +} + +template +inline typename c10::detail::ivalue_to_const_ref_overload_return::type IValue::to() const& { + return generic_to(*this, _fake_type{}); +} + +inline c10::List IValue::toIntList() && { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toIntList() const& { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toIntVector() const { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toIntVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toSymIntList() && { + AT_ASSERT( + isSymIntList() || isIntList(), + "Expected SymIntList or IntList but got ", + tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toSymIntList() const& { + AT_ASSERT( + isSymIntList() || isIntList(), + "Expected SymIntList or IntList but got ", + tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toSymIntVector() const { + AT_ASSERT(isSymIntList() || isIntList(), "Expected SymIntList or IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toSymIntVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline at::DimVector IValue::toDimVector() const { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toDimVector on null intrusive_ptr IValue"); + return createVectorLikeFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toDoubleList() && { + AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toDoubleList() const& { + AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toDoubleVector() const { + AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toDoubleVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List> IValue::toComplexDoubleList() && { + AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind()); + return c10::List>(moveToIntrusivePtr()); +} +inline c10::List> IValue::toComplexDoubleList() const& { + AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind()); + return c10::List>(toIntrusivePtr()); +} +inline std::vector> IValue::toComplexDoubleVector() const { + AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toComplexDoubleVector on null intrusive_ptr IValue"); + return createVectorFromList>( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toBoolList() && { + AT_ASSERT(isBoolList(), "Expected 
BoolList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toBoolList() const& { + AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline c10::List IValue::toTensorList() && { + AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toTensorList() const& { + AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toTensorVector() const { + AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toTensorVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List> IValue::toOptionalTensorList() && { + AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); + return c10::List>(moveToIntrusivePtr()); +} +inline c10::List> IValue::toOptionalTensorList() const& { + AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); + return c10::List>(toIntrusivePtr()); +} +inline std::vector> IValue::toOptionalTensorVector() const { + AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toOptionalTensorVector on null intrusive_ptr IValue"); + return createVectorFromList>( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toList() && { + AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toList() const& { + AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline c10::ArrayRef IValue::toListRef() const { + AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toListRef on null intrusive_ptr IValue"); + return static_cast(payload.u.as_intrusive_ptr) + ->list; +} +inline c10::Dict IValue::toGenericDict() && { + AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); + return c10::Dict(moveToIntrusivePtr()); +} +inline c10::Dict IValue::toGenericDict() const& { + AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); + return c10::Dict(toIntrusivePtr()); +} +inline c10::intrusive_ptr IValue::toTuple() && { + AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toTuple() const& { + AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); + return toIntrusivePtr(); +} +inline ivalue::Tuple& IValue::toTupleRef() const { + AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toTupleRef on null intrusive_ptr IValue"); + return *static_cast( + payload.u.as_intrusive_ptr); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Tuple) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} +template < + typename... 
Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t>> +inline IValue::IValue(const std::tuple& t) + : IValue(std::apply(c10::ivalue::Tuple::create, t)) { +} + +template < + typename... Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t>> +inline IValue::IValue(std::tuple&& t) + : IValue(std::apply(c10::ivalue::Tuple::create, std::move(t))) { +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::String) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} +inline IValue::IValue(std::string v) + : IValue(ivalue::ConstantString::create(std::move(v))) {} + +inline IValue::IValue(c10::impl::GenericList v) + : tag(Tag::GenericList) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); +} + +template > +inline IValue::IValue(c10::List&& v) : IValue(impl::toList(std::move(v))) {} +template > +inline IValue::IValue(const c10::List& v) : IValue(impl::toList(v)) {} +template > +inline IValue::IValue(at::ArrayRef v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + for (const auto& e : v) { + list.push_back(e); + } +} +template > +inline IValue::IValue(at::ArrayRef v) : IValue() { + auto vi = c10::asIntArrayRefSlowOpt(v); + if (vi.has_value()) { + // This list is entirely integers; ensure it is typed as + // an IntList so toIntList works + *this = IValue(*vi); + } else { + // This list has SymInts; type it as a SymInt + *this = IValue(impl::toList(c10::List())); + auto list = to>(); + list.reserve(v.size()); + for (const auto& e : v) { + list.push_back(e); + } + } +} +template > +inline IValue::IValue(at::OptionalArrayRef mb_v) : IValue() { + if (!mb_v.has_value()) return; + *this = IValue(*mb_v); +} +template > +inline IValue::IValue(const std::vector& v) : IValue() { + *this = IValue(at::ArrayRef(v)); +} +template > +inline IValue::IValue(std::vector&& v) : IValue() { + auto vi = c10::asIntArrayRefSlowOpt(v); + if (vi.has_value()) { + // This list is entirely integers; ensure it is typed as + // an IntList so toIntList works + *this = IValue(*vi); + } else { + // This list has SymInts; type it as a SymInt + *this = IValue(impl::toList(c10::List())); + auto list = to>(); + list.reserve(v.size()); + for (auto&& e : std::move(v)) { + list.push_back(std::move(e)); + } + } +} +template > +inline IValue::IValue(const std::vector& v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + for (const auto& e : v) { + list.push_back(e); + } +} + +template > +inline IValue::IValue(std::vector&& v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + if constexpr (std::is_same_v) { + for (auto e : v) { + list.push_back(e); + } + } else { + for (auto&& e : std::move(v)) { + list.push_back(std::move(e)); + } + } +} + +template > +inline IValue::IValue(c10::OptionalArrayRef v) : IValue() { + if (v.has_value()) { + *this = IValue(std::move(*v)); + } +} + +template +inline IValue::IValue(std::array v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + for (auto& e : v) { + list.push_back(std::move(e)); + } +} + +template > +inline IValue::IValue(c10::IListRef v) : IValue() { + constexpr bool boxed_type_constructs_ivalue = + std::is_constructible_v::boxed_type>; + // First, we try to use the boxed value. + // If we fail (either it's not in the boxed state, or its boxed type + // can not construct an IValue), we fallback to copying the list. 
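  // An illustrative sketch (not part of the original header) of the container
  // constructors declared above:
  //
  //   std::vector<double> src{1.0, 2.0};
  //   c10::IValue a(src);                         // stored as a DoubleList
  //   c10::IValue b(std::vector<int64_t>{3, 4});  // stored as an IntList
  //   auto round_trip = a.toDoubleVector();       // {1.0, 2.0}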
+ if (boxed_type_constructs_ivalue && v.isBoxed()) { + *this = IValue(impl::toList(v.toBoxed())); + } else { + c10::List list; + list.reserve(v.size()); + for (const auto& t : v) { + list.push_back(t); + } + *this = IValue(impl::toList(std::move(list))); + } +} + +inline IValue::IValue(c10::impl::GenericDict v) + : tag(Tag::GenericDict) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); +} +template +inline IValue::IValue(c10::Dict v) + : IValue(impl::toGenericDict(std::move(v))) {} + +template +inline IValue::IValue(std::unordered_map v) + : IValue(Dict()) { + auto dict = to>(); + dict.reserve(v.size()); + for (auto& e : v) { + dict.insert(std::move(e.first), std::move(e.second)); + } +} + +template > +inline IValue::IValue(std::optional v) : IValue() { + if (v.has_value()) { + *this = IValue(std::move(*v)); + } +} + +inline IValue::IValue(std::nullopt_t) : IValue() {} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Object) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::PyObject) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Enum) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue IValue::make_capsule( + intrusive_ptr blob) { + IValue iv; + iv.tag = Tag::Capsule; + iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); + return iv; +} + +template < + typename T, + std::enable_if_t, int>> +IValue::IValue(c10::intrusive_ptr custom_class) : tag(Tag::Object) { + auto classType = []() { + try { + return c10::getCustomClassType>(); + } catch (const c10::Error&) { + throw c10::Error( + "Trying to instantiate a class that isn't a registered custom class: " + + std::string(c10::util::get_fully_qualified_type_name())); + } + }(); + auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1); + ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release()); + +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Future) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Await) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::RRef) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Quantizer) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +template +inline IValue::IValue(c10::complex c) + : tag(Tag::ComplexDouble) { + auto v = c10::make_intrusive(c); + payload.u.as_intrusive_ptr = v.release(); +} + +inline const std::string& IValue::toStringRef() const { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringRef on null intrusive_ptr IValue"); + return static_cast( + payload.u.as_intrusive_ptr) + ->string(); +} +inline std::optional> IValue:: + toOptionalStringRef() const { + if (isNone()) { + return std::nullopt; + } + AT_ASSERT(isString(), "Expected std::optional but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called 
toOptionalStringRef on null intrusive_ptr IValue"); + return std::reference_wrapper( + static_cast(payload.u.as_intrusive_ptr) + ->string()); +} + +inline std::string_view IValue::toStringView() const { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringView on null intrusive_ptr IValue"); + return static_cast( + payload.u.as_intrusive_ptr) + ->string_view(); +} + +inline PyObject* IValue::toPyObject() const { + return toPyObjectHolder()->getPyObject(); +} + +template +inline std::optional IValue::toOptional() { + if (this->isNone()) { + return std::nullopt; + } + return this->to(); +} + +template +inline std::optional IValue::toOptional() const { + if (this->isNone()) { + return std::nullopt; + } + return this->to(); +} + +inline bool IValue::isCustomClass() const { + return torch::isCustomClass(*this); +} + +inline bool IValue::isSameIdentity(const IValue& rhs) const { + // We choose to not use memcmp for payload check due to potential random + // padding characters on union type + + // Semantics: + // 1. Immutable primitive values of the same type (Int, Double, None, Bool, + // Str) return value equality + // 2. If it is a tensor type, we need to take undefined tensor into account + // 3. Undefined_tensor is None and vice versa should be true + // 4. If it is a reference type (i.e. isIntrusivePtr()), then is True when + // the pointed-to object is the same. + // 5. False for all other comparisons. + if (this->isNone() && rhs.isNone()) { + return true; + } else if (this->isBool() && rhs.isBool()) { + // for bool type, do equality check + return this->toBool() == rhs.toBool(); + } else if (this->isTensor() && rhs.isTensor()) { + return this->payload.as_tensor.is_same(rhs.payload.as_tensor); + } else if (this->isTensor() && rhs.isNone()) { + // special case: undefined tensor and None are the same identity + return !this->payload.as_tensor.defined(); + } else if (this->isNone() && rhs.isTensor()) { + // special case: undefined tensor and None are the same identity + return !rhs.payload.as_tensor.defined(); + } else if (this->isInt() && rhs.isInt()) { + return this->toInt() == rhs.toInt(); + } else if (this->isDouble() && rhs.isDouble()) { + return this->toDouble() == rhs.toDouble(); + } else if (this->isString() && rhs.isString()) { + return this->toStringRef() == rhs.toStringRef(); + } else { + // for objects holding in IValue, do shallow compare on pointer address to + // testify the identity + return this->isIntrusivePtr() && rhs.isIntrusivePtr() && + this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; + } +} + +namespace ivalue { +namespace detail { + +template +IValue from_(T&& x, std::true_type) { + return IValue(std::forward(x)); +} +template +IValue from_(c10::intrusive_ptr x, std::false_type) { + return IValue(std::move(x)); +} +template +IValue from_(T&& /*x*/, std::false_type) { + static_assert( + guts::false_t::value, + "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)"); + return IValue(); +} +} // namespace detail + +template +IValue from(T&& x) { + return detail::from_( + std::forward(x), typename std::is_constructible::type{}); +} + +} // namespace ivalue + + +template <> +struct MaybeOwnedTraits { + using owned_type = IValue; + using borrow_type = IValue; + + static borrow_type createBorrow(const owned_type& from) { + if (!from.isPtrType()) { + return 
from; + } + if (from.isTensor()) { + return IValue(MaybeOwnedTraits::createBorrow(from.toTensor())); + } else { + return IValue(from.payload, from.tag); + } + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.clearToNone(); + if (!rhs.isPtrType()) { + lhs = rhs; + } else if (rhs.isTensor()) { + lhs = IValue(MaybeOwnedTraits::createBorrow(rhs.toTensor())); + } else { + lhs = IValue(rhs.payload, rhs.tag); + } + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.clearToNone(); + } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type&) { + return true; + } +}; + +template <> +struct IValue::TagType { + static TORCH_API c10::TypePtr get(const IValue&); +}; + +template <> +struct IValue::TagType { + static TORCH_API c10::TypePtr get(const IValue&); +}; + +template +TypePtr IValue::type() const { + return IValue::TagType::get(*this); +} + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue_to.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue_to.h new file mode 100644 index 0000000000000000000000000000000000000000..f750de76cfa9dc1ae0b1ef975526b38d70eb8bb0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ivalue_to.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace at { +class Tensor; +} // namespace at + +namespace c10 { +struct IValue; +namespace detail { +// Determine the return type of `IValue::to() const &`. It's a const +// reference when possible and a copy otherwise. It is in this +// separate header so that List can use it as well. 
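// For instance (sketch, not part of the original header), with the
// specializations below:
//
//   const at::Tensor& t = iv.to<at::Tensor>();  // const reference, no copy
//   int64_t n = iv.to<int64_t>();               // plain value
//
// where `iv` is assumed to be a const IValue holding values of those types.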
+template +struct ivalue_to_const_ref_overload_return { + using type = T; +}; + +template<> +struct ivalue_to_const_ref_overload_return { + using type = const at::Tensor&; +}; + +template<> +struct ivalue_to_const_ref_overload_return { + using type = const std::string&; +}; + +template<> +struct ivalue_to_const_ref_overload_return { + using type = const IValue&; +}; + +} // namespace detail +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/jit_type.h b/phivenv/Lib/site-packages/torch/include/ATen/core/jit_type.h new file mode 100644 index 0000000000000000000000000000000000000000..d8553560342b5298d458aa23c2b361d701d613f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/jit_type.h @@ -0,0 +1,2437 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + + +namespace torch::jit { +struct Function; +} // namespace torch::jit + + +namespace c10 { + +template +class Dict; +struct IValue; +struct FunctionSchema; +struct NamedType; +using OptNameList = std::optional>; + +void standardizeVectorForUnion(std::vector& reference, std::vector* to_fill); +void standardizeVectorForUnion(std::vector* to_flatten); + +inline bool is_contiguous_strides( + const IntArrayRef sizes, + const IntArrayRef strides) { + int n_dim = static_cast(sizes.size()); + if (n_dim == 0) { + return true; + } + + if (strides[n_dim - 1] != 1) { + return false; + } + + for (int i = n_dim - 2; i >= 0; i--) { + if (strides[i] != strides[i + 1] * sizes[i + 1]) { + return false; + } + } + return true; +} + +struct AnyType; +using AnyTypePtr = SingletonTypePtr; +// Any is the top of the type hierarchy, all other types are subtypes +// T <: Any, forall T +struct TORCH_API AnyType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Any"; + } + static const TypeKind Kind = TypeKind::AnyType; + // global singleton + static AnyTypePtr get(); + + private: + AnyType() : Type(TypeKind::AnyType) {} +}; + +inline std::string toString(const Type& type) { + return type.str(); +} + +// Shim for compatibility with code that uses TypePtr. +inline std::string toString(const TypePtr& typePtr) { + return toString(*typePtr); +} + +inline bool operator!=(const Type& lhs, const Type& rhs) { + return !(lhs == rhs); +} + +// common base for all types that have a single sub element +// e.g. 
Future[T], Optional[T], List[T] +template +struct SingleElementType : public SharedType { + static const TypeKind Kind = K; + + const TypePtr& getElementType() const { + return elem; + } + + bool hasFreeVariables() const override { + return getElementType()->hasFreeVariables(); + } + + at::ArrayRef containedTypes() const override { + return elem; + } + + bool equals(const Type& rhs) const override { + if (auto rhs_ = rhs.cast()) { + return *getElementType() == *rhs_->getElementType(); + } + return false; + } + + protected: + SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) { + if (!this->elem) { + throw std::runtime_error(c10::str( + "Can not create ", typeKindToString(Kind), " with None type")); + } + } + + private: + TypePtr elem; +}; + +struct UnionType; +using UnionTypePtr = std::shared_ptr; +struct TORCH_API UnionType : public SharedType { + friend struct Type; + + static const TypeKind Kind = TypeKind::UnionType; + + bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override; + + std::string str() const override; + + static UnionTypePtr create(std::vector reference); + + bool equals(const Type& rhs) const override; + + bool isUnionType() const override { + return true; + } + + at::ArrayRef containedTypes() const override { + return types_; + } + + // For testing purposes only + at::ArrayRef getTypes() const { + return types_; + } + + TypePtr createWithContained(std::vector contained_types) const override { + return create(std::move(contained_types)); + } + + bool canHoldType(const Type& type) const; + + bool hasFreeVariables() const override { + return has_free_variables_; + } + + std::optional toOptional() const; + + std::optional subtractTypeSet(std::vector& to_subtract) const; + + protected: + explicit UnionType(std::vector types, TypeKind kind=TypeKind::UnionType); + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override; + std::string unionStr( + const TypePrinter& printer = nullptr, + bool is_annotation_str = false) const; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool has_free_variables_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector types_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool can_hold_none_; + +}; + +struct OptionalType; +using OptionalTypePtr = std::shared_ptr; +// This type represents an optional type. There is one `Optional` for +// each element type. 
`Optional[T]` can accept both `T` and +// `None`(`std::nullopt` in C++) +// Subtype hierarchy for Optional: +// - Optional[T] <: Optional[R] iff T <: R +// - T <: Optional[R] if T <: R +// - None <: Optional[T] for all T +// - Optional[T] == Union[T, None] for all T +struct TORCH_API OptionalType : public UnionType { + static OptionalTypePtr create(const TypePtr& contained); + + static const TypeKind Kind = TypeKind::OptionalType; + + friend struct Type; + + bool equals(const Type& rhs) const override; + + const TypePtr& getElementType() const { + return contained_; + } + + at::ArrayRef containedTypes() const override { + return contained_; + } + + std::string str() const override { + std::stringstream ss; + ss << getElementType()->str() << "?"; + return ss.str(); + } + + TypePtr createWithContained( + std::vector contained_types) const override { + AT_ASSERT(contained_types.size() == 1); + return create(contained_types[0]); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + bool isUnionType() const override { + return true; + } + + // common cast Optional[Tensor] for undefined tensor type + static TypePtr ofTensor(); + // + // global singleton + static TypePtr get(TypePtr inner); + + private: + explicit OptionalType(const TypePtr& contained); + + TypePtr contained_; + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "Optional[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +template +inline std::optional merge_primitive( + const std::optional& a, + const std::optional& b) { + if (a.has_value() && b.has_value() && a.value() == b.value()) { + return a; + } + return std::optional{}; +} + +// If we see `a + b + c` and know that a, b, and c are the same size and have +// two dimensions (WxH), then we can generate a fused kernel for them. That +// fused kernel would likely have indexing math to handling both the W and H +// dimensions. However, if we knew the WxH dimensions were contiguous, we can +// pretend like we only have a single dimension, simplifying the indexing logic. +// This can be performed even if the dimensions are transposed, +// as long as a, b, and c are transposed in the same way. +// We'd like to have the compiler be able to do this dimensionality reduction, +// but simply knowing sizes is not enough. +// We can extend profiling to also record stride information. 
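
A brief illustrative sketch of the Optional subtype rules listed above (not part of the header; assumes the torch headers are on the include path):

    #include <ATen/core/jit_type.h>

    void optional_sketch() {
      using namespace c10;
      auto opt_int = OptionalType::create(IntType::get());      // Optional[int]
      // T <: Optional[T] and None <: Optional[T], per the hierarchy above.
      bool t_sub = IntType::get()->isSubtypeOf(*opt_int);       // expected: true
      bool none_sub = NoneType::get()->isSubtypeOf(*opt_int);   // expected: true
      bool str_sub = StringType::get()->isSubtypeOf(*opt_int);  // expected: false
      // str() appends "?" to the element type, so this is "int?".
      auto printed = opt_int->str();
    }
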
+// Rather than recording specific strides, +// we can simply order the strides from smallest to largest with +// `stride_indices` A contiguity marker on the smallest stride (c0) indicates +// the stride is precisely 1, otherwise a contiguity marker means that $stride_n +// = size_{n-1}*stride_{n-1}$ +struct TORCH_API Stride { + Stride() = default; + Stride( + const std::optional& stride_index, + std::optional contiguous, + const std::optional& stride) + : stride_index_(stride_index), contiguous_(contiguous), stride_(stride) {} + + bool operator==(const Stride& b) const { + return stride_index_ == b.stride_index_ && contiguous_ == b.contiguous_ && + stride_ == b.stride_; + } + + bool isComplete() const { + return stride_index_ && contiguous_ && stride_; + } + + std::optional stride_index_; + std::optional contiguous_; + std::optional stride_; +}; + +template <> +inline std::optional merge_primitive( + const std::optional& a, + const std::optional& b) { + std::optional left = a; + std::optional right = b; + if (!left.has_value()) { + left = {Stride()}; + } + if (!right.has_value()) { + right = {Stride()}; + } + + auto merged_index = + merge_primitive(left->stride_index_, right->stride_index_); + auto merged_cont = merge_primitive(left->contiguous_, right->contiguous_); + auto merged_stride = merge_primitive(left->stride_, right->stride_); + auto r = Stride(merged_index, merged_cont, merged_stride); + // normalize + if (!r.stride_index_.has_value() && !r.contiguous_.has_value() && + !r.stride_.has_value()) { + return std::optional{}; + } + + return r; +} + +struct TORCH_API ShapeSymbol { + // needed for use in `std::map` + ShapeSymbol() : value_(-1) {} + // is this symbol a fixed/static dimension + bool is_static() const { + return value_ >= 0; + } + bool operator==(const ShapeSymbol& b) const { + return value_ == b.value_; + } + bool operator<(const ShapeSymbol& b) const { + return value_ < b.value_; + } + + static ShapeSymbol fromStaticSize(int64_t val) { + return ShapeSymbol(val); + } + int64_t static_size() const { + TORCH_CHECK(is_static()); + return value_; + } + + int64_t value() const { + return value_; + } + + static ShapeSymbol newSymbol() { + return fromStaticSize(-static_cast(++num_symbols)); + } + friend TORCH_API std::ostream& operator<<( + std::ostream& os, + const ShapeSymbol& s); + + private: + ShapeSymbol(int64_t val) : value_(val) {} + int64_t value_; + static std::atomic num_symbols; +}; + +inline ShapeSymbol merge_primitive( + const ShapeSymbol& a, + const ShapeSymbol& b) { + if (a.is_static() && b.is_static() && a == b) { + return a; + } + return ShapeSymbol::newSymbol(); +} + +// Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown +// dims, partially known and fully known shapes are all supported. +struct TORCH_API SymbolicShape { + // Unranked shape constructor. + SymbolicShape() : dims_(std::nullopt) {} + + // Known rank but unknown dimentions. 
+ SymbolicShape(std::optional rank) : dims_(std::nullopt) { + if(!rank) { + return; + } + + std::vector shape_symbols; + shape_symbols.reserve(*rank); + for(size_t i = 0; i < *rank; ++i) { + shape_symbols.push_back(ShapeSymbol::newSymbol()); + } + dims_ = shape_symbols; + } + + // Mix of known and unknown ranks + SymbolicShape(const std::vector>& dims) { + std::vector shape_symbols; + shape_symbols.reserve(dims.size()); + for(std::optional dim: dims) { + if(!dim) { + shape_symbols.push_back(ShapeSymbol::newSymbol()); + } else { + shape_symbols.push_back(ShapeSymbol::fromStaticSize(*dim)); + } + } + dims_ = shape_symbols; + } + + void dump() const; + + SymbolicShape(std::vector dims) : dims_(std::move(dims)) {} + + SymbolicShape(c10::IntArrayRef dims) { + std::vector shape_symbols; + shape_symbols.reserve(dims.size()); + for(int64_t dim : dims) { + shape_symbols.push_back(ShapeSymbol::fromStaticSize(dim)); + } + dims_ = shape_symbols; + } + + ShapeSymbol operator[](size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + + ShapeSymbol at(size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + + // Returns rank or nullopt in case of unranked shape. + std::optional rank() const { + if(!dims_) { + return std::nullopt; + } + return dims_->size(); + } + + std::optional> sizes() const { + return dims_; + } + + std::optional> symbolicDims() const { + if (!dims_) { + return std::nullopt; + } + auto symbolic_dims = std::vector(); + for (const ShapeSymbol& s : *dims_) { + symbolic_dims.push_back(!s.is_static()); + } + return symbolic_dims; + } + + // Checks whether the shape is fully defined/complete, ie. rank and sizes + // of every dimension are known. + bool isComplete() const { + if(!dims_) { + return false; + } + for(auto d : *dims_) { + if(!d.is_static()) { + return false; + } + } + return true; + } + + // Create new SymbolicShape that is result of merging self and another + // SymbolicShape. Only dimensions that are static and equal will be + // preserved. + // If either of two shapes are of unknown rank or they have unmatching rank, + // result will be unranked. 
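
As a rough sketch of how ShapeSymbol and SymbolicShape fit together with the merge() declared just below (illustrative only; the commented results follow from the definitions above):

    #include <ATen/core/jit_type.h>

    void symbolic_shape_sketch() {
      using namespace c10;
      // Fully known shape [2, 3, 4].
      std::vector<int64_t> dims = {2, 3, 4};
      SymbolicShape known{c10::IntArrayRef(dims)};
      // Same rank, but the middle dimension is unknown (a fresh symbol).
      SymbolicShape partial{std::vector<std::optional<int64_t>>{2, std::nullopt, 4}};
      bool complete = known.isComplete();   // true: rank and every dim are static
      // merge() keeps dimensions that are static and equal (dims 0 and 2 here)
      // and replaces the rest with new symbols; mismatched or unknown ranks
      // would instead produce an unranked result.
      SymbolicShape merged = known.merge(partial);
      auto rank = merged.rank();            // 3
    }
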
+ SymbolicShape merge(const SymbolicShape& other) const; + + friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) { + return lhs.dims_ == rhs.dims_; + } + + friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) { + return !(lhs == rhs); + } + + private: + std::optional> dims_; +}; + +namespace detail { +inline bool isComplete(const Stride& s) { + return s.isComplete(); +} + +template +inline bool isComplete(const T& /*t*/) { + return true; +} +} + +template +struct VaryingShape { + using ListOfOptionalElements = std::vector>; + VaryingShape(const std::vector& vec) + : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {} + + VaryingShape(c10::ArrayRef vec) + : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {} + + VaryingShape(std::optional size = std::nullopt) : dims_(std::nullopt) { + if (size) { + dims_ = ListOfOptionalElements(*size); + } + } + + VaryingShape(ListOfOptionalElements dims) : dims_(std::move(dims)) {} + + VaryingShape(size_t size) : VaryingShape(std::optional(size)) {} + + bool operator==(const VaryingShape& other) const { + return dims_ == other.dims_; + } + + const std::optional &operator[](size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + + std::optional size() const { + if (!dims_) { + return std::nullopt; + } + const auto& dims = dims_.value(); + return dims.size(); + } + + const std::optional& sizes() const { + return dims_; + } + + TORCH_API VaryingShape merge(const VaryingShape& other) const; + + std::optional> concrete_sizes() const { + if (!dims_) { + return std::nullopt; + } + std::vector sizes; + sizes.reserve(dims_.value().size()); + for (auto d : *dims_) { + if (!d) { + return std::nullopt; + } + sizes.push_back(d.value()); + } + return sizes; + } + + bool isComplete() const { + if (!dims_) { + return false; + } + for (auto d : *dims_) { + if (!d || !detail::isComplete(*d)) { + return false; + } + } + return true; + } + + private: + std::optional dims_; +}; + +struct TensorType; +// TODO: investigate making this SingletonOrSharedTypePtr +using TensorTypePtr = std::shared_ptr; +// This type represents a single Tensor with a specific size +struct TORCH_API TensorType : public SharedType { + static TensorTypePtr create(const at::Tensor& t); + + // used by TensorType::create(size_t dim) which in turn used by + // shape_analysis.cpp + static TensorTypePtr create( + std::optional scalar_type, + std::optional device, + const VaryingShape& sizes, + const VaryingShape& strides, + std::optional requires_grad, + std::optional undefined = false, + bool tensor_contiguity = false); + + static TensorTypePtr create( + std::optional scalar_type, + std::optional device, + SymbolicShape sizes, + VaryingShape stride_, + std::optional requires_grad, + std::optional undefined = false); + + static TensorTypePtr create( + std::optional scalar_type, + std::optional device, + std::optional dim, + std::optional requires_grad); + + // overloaded create variadic template argument as it could not distinguish + // initializer list + static TensorTypePtr createContiguous( + at::ScalarType scalar_type, + at::Device device, + at::IntArrayRef sizes); + + static TypePtr fromNumberType(const Type& typ); + static TypePtr fromBoolType(); + + std::optional dim() const { + return sizes().size(); + } + + VaryingShape sizes() const; + + VaryingShape strides() const; + + const VaryingShape& stride_properties() const { + return strides_; + } + + const std::optional& 
device() const { + return device_; + } + const std::optional& scalarType() const { + return scalar_type_; + } + const std::optional& requiresGrad() const { + return requires_grad_; + } + bool requires_grad() const override { + return requires_grad_ ? *requires_grad_ : true; + } + + bool equals(const Type& rhs) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + std::string str() const override; + + std::string repr_str() const override { + if (isInferredType()) { + return str() + " (inferred)"; + } else { + return str(); + } + } + + std::optional numel() const { + size_t prod = 1; + const auto& shape = sizes(); + + for (size_t i = 0; i < shape.size(); i++) { + auto const &s = shape[i]; + if (!s.has_value()) { + return std::optional{}; + } + prod *= s.value(); + } + return prod; + } + + TensorTypePtr withRequiresGrad(std::optional s) { + auto copy = clone(); + copy->requires_grad_ = s; + return copy; + } + + TensorTypePtr withScalarType(std::optional st) { + auto copy = clone(); + copy->scalar_type_ = st; + return copy; + } + + TensorTypePtr withDim(std::optional d) { + auto copy = clone(); + // withDim is only used by the legacy executor + // that only cares about the rank, so create dummy symbols)) : + copy->sizes_ = SymbolicShape(d); + copy->strides_ = VaryingShape(d); + return copy; + } + + TensorTypePtr withStrides(VaryingShape sstrides) const { + auto cloned = clone(); + cloned->strides_ = std::move(sstrides); + return cloned; + } + + TensorTypePtr withSizesStrides( + at::IntArrayRef sizes, + at::IntArrayRef strides) const { + auto cloned = clone(); + auto ssizes = SymbolicShape(sizes); + cloned->sizes_ = ssizes; + cloned->strides_ = computeStrideProps(sizes, strides); + return cloned; + } + + TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const { + auto cloned = clone(); + cloned->sizes_ = std::move(ssizes); + return cloned; + } + + TensorTypePtr withSizes(at::IntArrayRef sizes) const { + return withSizesStrides( + sizes, contiguousStridesOf(sizes)); + } + + TensorTypePtr withDevice(const std::optional device) const { + auto copy = clone(); + copy->device_ = device; + return copy; + } + + TensorTypePtr dimensionedOnly() const { + auto copy = clone(); + copy->sizes_ = SymbolicShape(sizes().size()); + copy->strides_ = VaryingShape(sizes().size()); + return copy; + } + + TensorTypePtr contiguous() const { + auto cloned = clone(); + auto concrete_sizes = sizes().concrete_sizes(); + TORCH_INTERNAL_ASSERT(concrete_sizes.has_value()); + auto strides = computeStrideProps( + *concrete_sizes, + contiguousStridesOf(*concrete_sizes)); + cloned->strides_ = strides; + return cloned; + } + + const SymbolicShape& symbolic_sizes() const; + + TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const; + + bool matchTensor(const at::Tensor& t); + + // is all information about the type specified except for autograd? + // This replaces the notion of a 'CompleteTensorType' that used to exist + // in the type-hierarchy. Excluding require_grad and undefined allows + // this to match the old behavior. 
+ bool isComplete() const { + return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete(); + } + + bool isInferredType() const { + return is_inferred_; + } + + static TensorTypePtr getInferred() { + static auto valueInferred = TensorType::create( + /*scalar_type=*/{}, + /*device=*/{}, + /*sizes=*/SymbolicShape(), + /*stride=*/VaryingShape{}, + /*requires_grad=*/{}, + /*undefined=*/false); + valueInferred->is_inferred_ = true; + return valueInferred; + } + + // this property is used by GuardElimination + // please see `checkInputs` for more details + bool isSummarized() const { + return !(isComplete() && requiresGrad().has_value() && + undefined().has_value()); + } + + TensorTypePtr withUndefined() { + auto r = clone(); + r->undefined_ = true; + return r; + } + + TensorTypePtr withPossiblyUndefined() { + auto r = clone(); + r->undefined_ = std::nullopt; + return r; + } + + std::optional undefined() const { return undefined_; } + + static const TensorTypePtr& get(); + + static const TypeKind Kind = TypeKind::TensorType; + + static std::vector contiguousStridesOf( + at::IntArrayRef in_sizes, + at::MemoryFormat memory_format = MemoryFormat::Contiguous) { + auto contiguous_fn = [](const at::IntArrayRef& sizes, + const std::vector& dim_order) { + std::vector strides(sizes.size()); + if (sizes.empty()) // zero-dim case + return strides; + + strides[dim_order[0]] = 1; + for (size_t i = 1; i < dim_order.size(); i++) { + auto cur_dim = dim_order[i]; + auto pre_dim = dim_order[i - 1]; + strides[cur_dim] = strides[pre_dim] * sizes[pre_dim]; + } + return strides; + }; + + std::vector dim_order(in_sizes.size()); + if (memory_format == MemoryFormat::ChannelsLast) { + dim_order = {1, 3, 2, 0}; + } else if (memory_format == MemoryFormat::ChannelsLast3d) { + dim_order = {1, 4, 3, 2, 0}; + } else { + auto ndims = in_sizes.size(); + for (size_t i = 0; i < ndims; i++) { + dim_order[i] = static_cast(ndims - i - 1); // Reverse + } + } + return contiguous_fn(in_sizes, dim_order); + } + + private: + TensorType( + std::optional scalar_type, + std::optional device, + SymbolicShape sizes, + VaryingShape strides, + std::optional requires_grad, + std::optional undefined = false); + + TensorTypePtr clone() const { + return TensorTypePtr(new TensorType( + scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_)); + } + + static VaryingShape computeStrideProps( + at::IntArrayRef sizes, + at::IntArrayRef strides, + bool tensor_contiguity = false); + + std::optional scalar_type_; + std::optional device_; + SymbolicShape sizes_; + VaryingShape strides_; + std::optional requires_grad_; + // we exploit the fact certain tensors must be zero in the autograd to + // optimize gradient computation. Such zero tensors are currently implemented + // with `UndefinedTensorImpl.` They can be handled only by special operators + // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false. + // Normally, `undefined_` is set to false, unless a type was created + // with `withUndefined` + // This will also mean that `undefined` tensors will fail + // `subtypeOf(TensorType::get())` check + // undefined_ may become `std::nullopt` if the tensor was observed to be both + // defined and undefined. However, no tensor type starts out with + // `undefined_` set to `std::nullopt` + std::optional undefined_; + // Represents whether or not this type was inferred. 
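
A short illustrative sketch of the TensorType factories and the stride helper above (not part of the header; assumes ATen is available, and the commented stride values follow from contiguousStridesOf() as written):

    #include <ATen/ATen.h>
    #include <ATen/core/jit_type.h>

    void tensor_type_sketch() {
      using namespace c10;
      // Strides implied by a contiguous [2, 3, 4] tensor: {12, 4, 1}.
      auto contig = TensorType::contiguousStridesOf({2, 3, 4});
      // Channels-last strides for an NCHW shape [8, 3, 32, 32]: {3072, 1, 96, 3}.
      auto cl = TensorType::contiguousStridesOf({8, 3, 32, 32}, MemoryFormat::ChannelsLast);

      // A TensorType built from a concrete tensor records dtype, device, sizes
      // and strides, so it is "complete" in the sense described above.
      at::Tensor t = at::rand({2, 3});
      auto tt = TensorType::create(t);
      bool complete = tt->isComplete();   // expected: true
      auto dtype = tt->scalarType();      // at::kFloat for the default dtype
    }
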
+ bool is_inferred_ = false; +}; + +struct ListType; +using ListTypePtr = std::shared_ptr; +struct TORCH_API ListType + : public SingleElementType { + // It's not exactly a singleton, but there should be exactly one instance of + // List[T] for every T + friend struct Type; + template + static ListTypePtr create(T&&... all) { + return ListTypePtr( + new ListType(std::forward(all)...)); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << getElementType()->str() << "[]"; + return ss.str(); + } + TypePtr createWithContained( + std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + // global singleton + // Given an inner type T and an identifier, + // this function wil return the global singleton type pointer + // the type List. + // The extra "identifier" argument is needed beccause we have multiple container types + // that all re-use this function (List, array, etc.) + static TypePtr get(const std::string& identifier, TypePtr inner); + + // common cast List[Tensor] + static ListTypePtr ofTensors(); + static ListTypePtr ofOptionalTensors(); + static ListTypePtr ofInts(); + static ListTypePtr ofSymInts(); + static ListTypePtr ofFloats(); + static ListTypePtr ofComplexDoubles(); + static ListTypePtr ofBools(); + static ListTypePtr ofStrings(); + static ListTypePtr ofNumbers(); + + private: + ListType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "List[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +struct DictType; +using DictTypePtr = std::shared_ptr; +struct TORCH_API DictType : public SharedType { + friend struct Type; + static const TypeKind Kind = TypeKind::DictType; + + static DictTypePtr create(TypePtr key, TypePtr value) { + auto kind = key->kind(); + if (auto dyn = key->castRaw()) { + kind = dyn->dynamicKind(); + } + switch (kind) { + case TypeKind::AnyType: + case TypeKind::IntType: + case TypeKind::BoolType: + case TypeKind::FloatType: + case TypeKind::ComplexType: + case TypeKind::StringType: + case TypeKind::TensorType: + case TypeKind::DeviceObjType: + return DictTypePtr(new DictType(std::move(key), std::move(value))); + default: + TORCH_CHECK(false, + "Cannot create dict for key type '", + key->str(), + "', only int, float, complex, Tensor, device and string keys are supported"); + } + } + + // aligned with the format in FunctionSchema + std::string str() const override { + std::stringstream ss; + ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str() + << ")"; + return ss.str(); + } + + TypePtr createWithContained( + std::vector contained_types) const override { + if (contained_types.size() != 2) { + throw std::runtime_error("Expected 2 contained types"); + } + return create(std::move(contained_types.at(0)), std::move(contained_types.at(1))); + } + + const TypePtr& getKeyType() const { + return types.at(0); + } + + const TypePtr& getValueType() const { + return types.at(1); + } + + bool hasFreeVariables() const override { + return has_free_variables; + } + + at::ArrayRef containedTypes() const override { + return types; + } + + bool equals(const Type& rhs) const override { + if (auto* dict_rhs = rhs.castRaw()) { + return *getKeyType() == *(dict_rhs->getKeyType()) && + *getValueType() == *(dict_rhs->getValueType()); 
+ } + return false; + } + + // global singleton + // Given an inner type T and an identifier, + // this function will return the global singleton type pointer + // the type List. + // The extra "identifier" argument is needed because we have multiple container types + // that all re-use this function (Dict and unordered_map) + static TypePtr get(const std::string& identifier, TypePtr key, TypePtr val); + + private: + DictType(TypePtr key, TypePtr value) + : SharedType(TypeKind::DictType), + has_free_variables( + key->hasFreeVariables() || value->hasFreeVariables()) { + types.reserve(2); + types.push_back(std::move(key)); + types.push_back(std::move(value)); + } + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override; + + std::vector types; + bool has_free_variables; +}; + +struct FutureType; +using FutureTypePtr = std::shared_ptr; + +struct TORCH_API FutureType + : public SingleElementType { + friend struct Type; + template + static FutureTypePtr create(TypePtr elem) { + return FutureTypePtr( + new FutureType(std::move(elem))); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << "Future(" << getElementType()->str() << ")"; + return ss.str(); + } + TypePtr createWithContained( + std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + if (Type::isSubtypeOfExt(rhs, why_not)) { + return true; + } + if (auto rhs_ = rhs.castRaw()) { + return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not); + } + return false; + } + + private: + FutureType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "Future[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +struct AwaitType; +using AwaitTypePtr = std::shared_ptr; + +struct TORCH_API AwaitType + : public SingleElementType { + friend struct Type; + template + static AwaitTypePtr create(TypePtr elem) { + return AwaitTypePtr( + new AwaitType(std::move(elem))); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << "Await(" << getElementType()->str() << ")"; + return ss.str(); + } + TypePtr createWithContained( + std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + if (Type::isSubtypeOfExt(rhs, why_not)) { + return true; + } + if (auto rhs_ = rhs.castRaw()) { + return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not); + } + return false; + } + + private: + AwaitType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "Await[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +struct RRefType; +using RRefTypePtr = std::shared_ptr; + +struct TORCH_API RRefType + : public SingleElementType { + friend struct Type; + template + static RRefTypePtr create(TypePtr elem) { + return RRefTypePtr( + new RRefType(std::move(elem))); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << "RRef(" << getElementType()->str() << ")"; + return ss.str(); + } + TypePtr createWithContained( + 
std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + private: + RRefType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "RRef[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +// Any should never appear in a named type like a class, namedtuple or +// interface. If it does, then dynamic type information will be lost in the +// Pickler, leading to hard-to-track-down bugs that will only occur +// after saving or loading a model. This is because we rely on the +// static types in named types to reconstruct type tags of loaded +// values. Lifting this restriction requires solving the serialization +// problem first. +TORCH_API void checkNoAny( + const Type& base, + const char* what, + const std::string& attrname, + const TypePtr& attrtype); + +struct TupleType; +using TupleTypePtr = std::shared_ptr; +using NameList = std::vector; +// This type represents a Tuple +struct TORCH_API TupleType : public NamedType { + + static TupleTypePtr createNamed(const std::optional& name, + const std::vector& field_names, + const std::vector& field_types, + std::vector& field_defaults); + + static TupleTypePtr createNamed(const std::optional& name, + const std::vector& field_names, + const std::vector& field_types); + + static TupleTypePtr createNamed(const std::optional& name, + const std::vector& field_names, + const std::vector& field_types); + + static TupleTypePtr create( + std::vector types) { + return TupleTypePtr(new TupleType( + std::move(types), + std::nullopt, + nullptr)); // NOLINT(modernize-make-shared) + } + static TupleTypePtr create() { + return create({}); + } + + at::ArrayRef elements() const { + return elements_; + } + + bool equals(const Type& rhs) const override; + bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override; + + std::string str() const override; + bool hasFreeVariables() const override { + return has_free_variables_; + } + at::ArrayRef containedTypes() const override { + return elements_; + } + TypePtr createWithContained( + std::vector contained_types) const override { + return std::shared_ptr( + new TupleType(std::move(contained_types), name(), schema())); + } + const std::shared_ptr& schema() const { + return schema_; + } + std::optional> names() const; + + static const TypeKind Kind = TypeKind::TupleType; + + private: + template + static TupleTypePtr createWithSpec( + const std::optional& name, + const std::vector& field_names, + const std::vector& field_types, + std::vector& field_defaults); + + TupleType( + std::vector elements_, + std::optional name, + std::shared_ptr schema); + + bool compare( + const Type& rhs, + const std::function& fn) const { + if (rhs.kind() != kind()) { + return false; + } + + const auto& l_elements = elements(); + const auto& r_elements = rhs.castRaw()->elements(); + if (l_elements.size() != r_elements.size()) + return false; + for (size_t i = 0; i < l_elements.size(); ++i) { + if (!fn(*l_elements[i], *r_elements[i])) + return false; + } + return true; + } + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override; + + std::vector elements_; + bool has_free_variables_; + std::shared_ptr schema_; +}; + +// the common supertype of all Enums, only used in operator registraion. 
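
An illustrative sketch of the container-type factories declared above (not part of the header; the qualified name is a made-up placeholder and printed forms are approximate):

    #include <ATen/core/jit_type.h>

    void container_sketch() {
      using namespace c10;
      // Dict with string keys and Tensor values; only the key kinds listed in
      // DictType::create() are accepted.
      auto dict = DictType::create(StringType::get(), TensorType::get());

      // Anonymous tuple (int, str) and a named tuple with typed fields.
      auto tup = TupleType::create({IntType::get(), StringType::get()});
      std::vector<std::string> field_names = {"x", "y"};
      std::vector<TypePtr> field_types = {FloatType::get(), FloatType::get()};
      auto point = TupleType::createNamed(
          QualifiedName("my.Point"), field_names, field_types);
      auto names = point->names();  // expected: {"x", "y"}
    }
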
+// EnumType <: AnyEnumType for all Enums +struct AnyEnumType; +using AnyEnumTypePtr = SingletonTypePtr; +struct TORCH_API AnyEnumType final : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "AnyEnumType"; + } + static const TypeKind Kind = TypeKind::AnyEnumType; + // global singleton + static AnyEnumTypePtr get(); +private: + AnyEnumType() + : Type(TypeKind::AnyEnumType) {} +}; + +struct NumberType; +using NumberTypePtr = SingletonTypePtr; +// This type represents a Python number +// Subtype hierarchy for Number Types (NumberType as the base type): +// IntType <: NumberType +// FloatType <: NumberType +// ComplexType <:NumberType +// +// WARNING: if you add a new subtype of NumberType that is not +// represented by a global singleton, you need to change NumberTypePtr +// to a SingletonOrSharedTypePtr and deal with NumberType needing to +// both inherit and not inherit from SharedType! +struct TORCH_API NumberType : public Type { + bool equals(const Type& rhs) const override; + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + std::string str() const override { + return "Scalar"; // match what PythonArgParser says for clarity + } + static const TypeKind Kind = TypeKind::NumberType; + // global singleton + static NumberTypePtr get(); + + protected: + NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {} + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "number"; // technically not a valid python type, but + // we need to use it when parsing back in annotations + // for implicit conversions + } +}; + +struct FloatType; +using FloatTypePtr = SingletonTypePtr; +// This type represents a Python float number +struct TORCH_API FloatType : public NumberType { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "float"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + } + static const TypeKind Kind = TypeKind::FloatType; + // global singleton + static FloatTypePtr get(); + + private: + FloatType() : NumberType(TypeKind::FloatType) {} + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "float"; + } +}; + +struct ComplexType; +using ComplexTypePtr = SingletonTypePtr; +// This type represents a Python float number +struct TORCH_API ComplexType : public NumberType { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "complex"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + } + static const TypeKind Kind = TypeKind::ComplexType; + // global singleton + static ComplexTypePtr get(); + + private: + ComplexType() : NumberType(TypeKind::ComplexType) {} + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "complex"; + } +}; + +// We need to introduce `SymIntType` to represent the `SymInt` type +// used in function schemas e.g. `aten::narrow_copy(... 
SymInt length) +// `SymInt` will be used to enable tracing arithmetic operations on +// dimension values. Please see [SymInt.h] for more information +struct SymIntType; +using SymIntTypePtr = SingletonTypePtr; +struct TORCH_API SymIntType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "SymInt"; + } + std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override { + return "int"; + } + static const TypeKind Kind = TypeKind::SymIntType; + // global singleton + static SymIntTypePtr get(); + + private: + SymIntType() : Type(TypeKind::SymIntType) {} +}; + +struct SymFloatType; +using SymFloatTypePtr = SingletonTypePtr; +struct TORCH_API SymFloatType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "SymFloat"; + } + std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override { + return "float"; + } + static const TypeKind Kind = TypeKind::SymFloatType; + // global singleton + static SymFloatTypePtr get(); + + private: + SymFloatType() : Type(TypeKind::SymFloatType) {} +}; + +struct SymBoolType; +using SymBoolTypePtr = SingletonTypePtr; +struct TORCH_API SymBoolType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "SymBool"; + } + std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override { + return "bool"; + } + static const TypeKind Kind = TypeKind::SymBoolType; + // global singleton + static SymBoolTypePtr get(); + + private: + SymBoolType() : Type(TypeKind::SymBoolType) {} +}; + +struct IntType; +using IntTypePtr = SingletonTypePtr; +// This type represents a Python int number +struct TORCH_API IntType : public NumberType { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "int"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + } + static const TypeKind Kind = TypeKind::IntType; + // global singleton + static IntTypePtr get(); + + private: + IntType() : NumberType(TypeKind::IntType) {} + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "int"; + } +}; + +struct BoolType; +using BoolTypePtr = SingletonTypePtr; +// This node represents a Python bool value +struct TORCH_API BoolType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "bool"; + } + static const TypeKind Kind = TypeKind::BoolType; + // global singleton + static BoolTypePtr get(); + + private: + BoolType() : Type(TypeKind::BoolType) {} +}; + +struct StringType; +using StringTypePtr = SingletonTypePtr; +// This type represents a Python string +struct TORCH_API StringType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + // we only use "str" (not "string") in both FunctionSchema and script + return annotation_str(); + } + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return 
"str"; + } + static const TypeKind Kind = TypeKind::StringType; + // global singleton + static StringTypePtr get(); + + private: + StringType() : Type(TypeKind::StringType) {} +}; + +struct StorageType; +using StorageTypePtr = SingletonTypePtr; +struct TORCH_API StorageType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return annotation_str(); + } + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "Storage"; + } + static const TypeKind Kind = TypeKind::StorageType; + // global singleton + static StorageTypePtr get(); + + private: + StorageType() : Type(TypeKind::StorageType) {} +}; + +struct FunctionType; +using FunctionTypePtr = std::shared_ptr; +struct TORCH_API FunctionType : public NamedType { + static FunctionTypePtr create(torch::jit::Function* function) { + return FunctionTypePtr( + new FunctionType(function)); // NOLINT(modernize-make-shared) + } + bool equals(const Type& rhs) const override { + if (auto func_type = rhs.cast()) { + return func_type->function_ == function_; + } + + return false; + } + std::string str() const override { + return "Function"; + } + torch::jit::Function* function() const { + return function_; + } + static const TypeKind Kind = TypeKind::FunctionType; + + private: + FunctionType(torch::jit::Function* function); + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name()->qualifiedName(); + } + torch::jit::Function* function_; +}; + +struct NoneType; +using NoneTypePtr = SingletonTypePtr; +// This type represents a Python None +struct TORCH_API NoneType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "NoneType"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override; + + static const TypeKind Kind = TypeKind::NoneType; + // global singleton + static NoneTypePtr get(); + + private: + NoneType() : Type(TypeKind::NoneType) {} +}; + +struct GeneratorType; +using GeneratorTypePtr = SingletonTypePtr; +// This type represents a Generator +struct TORCH_API GeneratorType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Generator"; + } + static const TypeKind Kind = TypeKind::GeneratorType; + // global singleton + static GeneratorTypePtr get(); + + private: + GeneratorType() : Type(TypeKind::GeneratorType) {} +}; + +struct QuantizerType; +using QuantizerTypePtr = SingletonTypePtr; +// This type represents a Quantizer +struct TORCH_API QuantizerType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Quantizer"; + } + static const TypeKind Kind = TypeKind::QuantizerType; + // global singleton + static QuantizerTypePtr get(); + + private: + QuantizerType() : Type(TypeKind::QuantizerType) {} +}; + +struct QSchemeType; +using QSchemeTypePtr = SingletonTypePtr; +// This type represents a QScheme +struct TORCH_API QSchemeType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "QScheme"; + } + static const TypeKind Kind = TypeKind::QSchemeType; + // global singleton + static QSchemeTypePtr get(); 
+ + private: + QSchemeType() : Type(TypeKind::QSchemeType) {} +}; + +struct DeviceObjType; +using DeviceObjTypePtr = SingletonTypePtr; +// This type represents a Device +struct TORCH_API DeviceObjType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Device"; + } + static const TypeKind Kind = TypeKind::DeviceObjType; + // global singleton + static DeviceObjTypePtr get(); + + private: + DeviceObjType() : Type(TypeKind::DeviceObjType) {} +}; + +struct StreamObjType; +using StreamObjTypePtr = SingletonTypePtr; +// This type represents a Generator +struct TORCH_API StreamObjType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Stream"; + } + static const TypeKind Kind = TypeKind::StreamObjType; + // global singleton + static StreamObjTypePtr get(); + +private: + StreamObjType() : Type(TypeKind::StreamObjType) {} +}; + +struct VarType; +using VarTypePtr = std::shared_ptr; +// This type represents a type variable, used in FunctionSchema +struct VarType : public SharedType { + static VarTypePtr create(std::string name_) { + return VarTypePtr(new VarType(std::move(name_))); + } + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return name(); + } + const std::string& name() const { + return name_; + } + bool hasFreeVariables() const override { + return true; + } + static const TypeKind Kind = TypeKind::VarType; + + private: + VarType(std::string name_) + : SharedType(TypeKind::VarType), name_(std::move(name_)) {} + std::string name_; +}; + +struct CapsuleType; +using CapsuleTypePtr = SingletonTypePtr; +// This type represents a Python Capsule. +// It does not appear in the IR and is only used during runtime +struct TORCH_API CapsuleType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Capsule"; + } + static const TypeKind Kind = TypeKind::CapsuleType; + // global singleton + static CapsuleTypePtr get(); +private: + CapsuleType() + : Type(TypeKind::CapsuleType) {} +}; + +struct PyObjectType; +using PyObjectTypePtr = SingletonTypePtr; +// This type represents a PyObject Type +struct TORCH_API PyObjectType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "PyObject"; + } + static const TypeKind Kind = TypeKind::PyObjectType; + // global singleton + static PyObjectTypePtr get(); +private: + PyObjectType() + : Type(TypeKind::PyObjectType) {} +}; + +enum class TypeVerbosity { + None, + Type, + TypeAndStride, + Full, + Symbolic, + Default = Full, +}; + +TORCH_API TypeVerbosity type_verbosity(); + +TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t); +template +TORCH_API std::ostream& operator<<( + std::ostream& out, + const VaryingShape& t); +TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s); +TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s); +TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s); +// what is the type, ignoring extra size/shape information? +// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) + +// `unshapedType` is used to remove Tensor subtypes. 
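
As a small aside, a sketch of the two printing paths declared above, the stream operator and str()/annotation_str() (illustrative only; exact output can vary with type_verbosity()):

    #include <iostream>
    #include <ATen/core/jit_type.h>

    void printing_sketch() {
      using namespace c10;
      auto list_of_tensors = ListType::ofTensors();
      // Schema-style spelling from str(), e.g. "Tensor[]".
      std::cout << list_of_tensors->str() << '\n';
      // Annotation-style spelling, e.g. "List[Tensor]".
      std::cout << list_of_tensors->annotation_str() << '\n';
      // The operator<< overload declared above also streams a Type directly.
      std::cout << *list_of_tensors << '\n';
    }
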
We treat all Tensor +// subtypes as simply "Tensor"; we also create a new version of any +// container types in which internal Tensors have undergone the same +// operation. This is used for type comparisons between two Tensor types +// (`unshapedType` means that we don't falsely return `false` for e.g. +// Tensors of different dimensions). It's also used in the alias +// analysis pass. +// Be careful with calls because this can be very slow. If calling this +// on a graph, use `EraseShapeInformation` in shape_analysis.h +inline TypePtr unshapedType(const TypePtr& type) { + if (type->isSubtypeOf(*TensorType::get())) { + return TensorType::get(); + } + at::ArrayRef contained = type->containedTypes(); + if (contained.empty()) { + return type; + } + return type->withContained(fmap(type->containedTypes(), unshapedType)); +} + +inline TypePtr TensorType::fromNumberType(const Type& typ) { + if (typ.isSubtypeOf(*IntType::get())) { + return TensorType::createContiguous(at::kLong, at::kCPU, {}); + } else if (typ.isSubtypeOf(*FloatType::get())) { + return TensorType::createContiguous(at::kDouble, at::kCPU, {}); + } else if (typ.isSubtypeOf(*BoolType::get())) { + return TensorType::createContiguous(at::kBool, at::kCPU, {}); + } else if (typ.kind() == NumberType::Kind) { + return TensorType::create(std::nullopt, at::kCPU, {}, std::nullopt); + } + TORCH_CHECK(false, "Unknown number type: ", typ.str()); +} +inline TypePtr TensorType::fromBoolType() { + return TensorType::createContiguous(at::kBool, at::kCPU, {}); +} + +inline std::optional tryScalarTypeFromJitType(const Type& type) { + if (type == *FloatType::get()) { + return at::typeMetaToScalarType(c10::get_default_dtype()); + } else if (type == *IntType::get()) { + return at::ScalarType::Long; + } else if (type == *BoolType::get()) { + return at::ScalarType::Bool; + } + return std::nullopt; +} + +inline at::ScalarType scalarTypeFromJitType(const Type& type) { + auto result = tryScalarTypeFromJitType(type); + TORCH_CHECK( + result, + "Add new condition, expected Float, Complex, Int, or Bool but got", + type.str()); + return *result; +} + +// Attempt to find the correct supertype of the two types `t1` and `t2`. +// If no supertype is found, then nullopt will be returned if +// `default_to_union` is false, and `Union[t1, t2]` will be returned +// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`, +// then `t2` will be returned (and vice versa). +// +// Two different tensortypes will return dynamic. +// +// Currently we chose not to support returning a NumberType for +// two types from the set of {FloatType, IntType, ComplexType}, because +// there is a lack of operator support for NumberType. +// +// If `type_hint` is an `InterfaceType`, then we can use that as a +// potential supertype for `ClassType`s in the list. Otherwise, we have +// no way to find and use some common interface type +TORCH_API std::optional unifyTypes( + const TypePtr& t1, + const TypePtr& t2, + bool default_to_union = false, + const TypePtr& type_hint = nullptr); + +TORCH_API std::optional unifyTypeList( + at::ArrayRef elements, + std::ostream& why_not, + bool default_to_union = false, + const TypePtr& type_hint = nullptr); + +namespace detail { +template +struct getTypePtr_ final { + static decltype(auto) call() { + return ([]() { + try { + return getCustomClassType(); + } catch(const c10::Error&) { + TORCH_CHECK( + false, + "Type ", + c10::util::get_fully_qualified_type_name(), + " could not be converted to any of the known types." 
+ ); + } + }()); + } +}; + +template +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return getTypePtr_::call(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return AnyType::get(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return TensorType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StorageType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StreamObjType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return FloatType::get(); + } +}; +template <> +struct getTypePtr_> final { + static decltype(auto) call() { + return ComplexType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; + +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return SymIntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; + +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return SymFloatType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return FloatType::get(); + } +}; + +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return SymBoolType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return BoolType::get(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return DeviceObjType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return BoolType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return NumberType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return QSchemeType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return TypeFactory::create( + TypeFactory::get()); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StringType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StringType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StringType::get(); + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per vector" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = ListType::get("vector", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per ArrayRef" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. 
+ static auto type = ListType::get("ArrayRef", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_ final { + static const auto& call() { + static auto type = ListType::create(getMaybeFakeTypePtr_::call()); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per List" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = ListType::get("List", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + static auto type = ListType::get("List", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per array" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + // (Concatenating the length onto the end of the string because we want a unique + // type_ptr created for every std::array type). + static auto type = ListType::get(std::string("array") + std::to_string(N), inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_key_type = getMaybeFakeTypePtr_::call(); + static auto inner_val_type = getMaybeFakeTypePtr_::call(); + // The "per unordered_map" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = DictType::get("unordered_map", inner_key_type, inner_val_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_key_type = getMaybeFakeTypePtr_::call(); + static auto inner_val_type = getMaybeFakeTypePtr_::call(); + // The "per Dict" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = DictType::get("Dict", inner_key_type, inner_val_type); + return type; + } +}; + +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per std::optional" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = OptionalType::get(inner_type); + return type; + } +}; + + +template<> +struct getTypePtr_ final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per std::optional" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = OptionalType::get(inner_type); + return type; + } +}; + +template +struct getMaybeFakeTypePtr_ final { + static const auto& call() { + // The "per std::optional" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. 
+ static auto inner_type = getMaybeFakeTypePtr_::call(); + static auto type = OptionalType::get(inner_type); + return type; + } +}; + +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto type = ([]() { + std::vector contained_types = { + (getMaybeFakeTypePtr_::call())... + }; + return TupleType::create(std::move(contained_types)); + })(); + return type; + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return NoneType::get(); + } +}; +} // namespace detail +template +inline decltype(auto) getTypePtr() { + // TODO: static_assert that a templated function exists, and throw a friendly + // error message if not + return detail::getMaybeFakeTypePtr_::call(); +} + +template +inline TypePtr getTypePtrCopy() { + // TODO: static_assert that a templated function exists, and throw a friendly + // error message if not + return getTypePtr(); +} + +template +inline decltype(auto) getFakeTypePtr() { + return detail::getMaybeFakeTypePtr_::call(); +} + +template +inline TypePtr getFakeTypePtrCopy() { + return getFakeTypePtr(); +} + +using TypeEnv = std::unordered_map; +struct MatchTypeReturn { + MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {} + static MatchTypeReturn Success() { + return MatchTypeReturn(); + } + bool success() const { + return !reason_.has_value(); + } + const std::string& reason() const { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return reason_.value(); + } + + private: + MatchTypeReturn() + : reason_(std::nullopt) {} + std::optional reason_; // is there is no match, this contains the reason +}; + +// attempt to match the type variables in formal to actual, adding them to type_env. +// If no match is possible this returns a MatchTypeReturn with r.success() == false +// and a r.reason() that describes why it could not match. +// note: It is possible to successfully match a formal, but for type variables +// in the formal to still not be defined. In particular, None matches Optional[T] +// but does not define the value of T. +TORCH_API MatchTypeReturn +matchTypeVariables(const TypePtr& formal, const TypePtr& actual, TypeEnv& type_env); + +// replace type variables appearing in `type` with the values in +// `type_env`. Returns nullptr if a variable used in `type` +// does not appear in `type_env` +TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, TypeEnv& type_env); + +TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type); + +struct InterfaceType; +using InterfaceTypePtr = std::shared_ptr; + +// Interfaces are a list of abstract methods that a class might meet. +// If a class provides those methods, it implicitly meets the interface. + +// Subtype relations for Interface with ClassType: +// lhs (ClassType or InterfaceType) is a subtype of rhs if: +// 1. lhs methods are a superset of rhs methods +// 2. 
if rhs is module interface, the lhs must be module interface or module itself +struct TORCH_API InterfaceType : public NamedType { + static InterfaceTypePtr create( + QualifiedName qualifiedName, bool is_module=false); + + bool equals(const Type& rhs) const override { + if (auto user_rhs = rhs.castRaw()) { + return isSubTypeImpl(*this, *user_rhs, nullptr) && + isSubTypeImpl(*user_rhs, *this, nullptr); + } + return false; + } + + std::string str() const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return std::string("InterfaceType<") + name()->name() + ">"; + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + // try to find a method of this interface, + // returns nullptr if not found. + const FunctionSchema* getMethod(const std::string& name) const; + void addMethod(FunctionSchema schema); + const std::vector& methods() const { + return *methods_; + } + + bool is_module() const override{ + return is_module_; + } + static const TypeKind Kind = TypeKind::InterfaceType; + ~InterfaceType() override = default; + private: + InterfaceType(QualifiedName name, bool is_module); + static bool isSubTypeImpl( + const InterfaceType& lhs, + const InterfaceType& rhs, + std::ostream* why_not); + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name()->qualifiedName(); + } + + // shared_ptr so that this header does not have to depend on + // FunctionSchema.h + std::shared_ptr> methods_; + // flag to distinguish if it's an interface type from a module or not + bool is_module_; +}; + +template +struct EnumerationType : public Type { +static const TypeKind Kind = K; + +bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); +} + +protected: +EnumerationType() : Type(Kind) {} +}; + +// WARNING: These enumeration types below DO NOT actually get parsed out +// from the logical schema strings, instead they are mapped as ints. 
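
A minimal illustrative sketch of creating an interface type with the factory above (the qualified name is a made-up placeholder; methods would normally be added from real function schemas):

    #include <ATen/core/jit_type.h>

    void interface_sketch() {
      using namespace c10;
      // A (non-module) interface; is_module defaults to false.
      auto runnable = InterfaceType::create(QualifiedName("my.Runnable"));
      const FunctionSchema* fwd = runnable->getMethod("forward");
      // getMethod() returns nullptr when the interface has no such method,
      // so nothing is found here because no schema was registered.
      bool has_forward = (fwd != nullptr);  // expected: false
    }
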
To +// observe these types, use real_type() instead of type() on Argument + +struct ScalarTypeType; +using ScalarTypeTypePtr = SingletonTypePtr; +struct TORCH_API ScalarTypeType : public EnumerationType { +std::string str() const override { +return "ScalarType"; +} +static const TypeKind Kind = TypeKind::ScalarTypeType; +// global singleton +static ScalarTypeTypePtr get(); + +private: +ScalarTypeType() {} +}; + +struct MemoryFormatType; +using MemoryFormatTypePtr = SingletonTypePtr; +struct TORCH_API MemoryFormatType : public EnumerationType { +std::string str() const override { +return "MemoryFormat"; +} +static const TypeKind Kind = TypeKind::MemoryFormatType; +// global singleton +static MemoryFormatTypePtr get(); + +private: +MemoryFormatType() {} +}; + +struct LayoutType; +using LayoutTypePtr = SingletonTypePtr; +struct TORCH_API LayoutType : public EnumerationType { +std::string str() const override { +return "Layout"; +} +static const TypeKind Kind = TypeKind::LayoutType; +// global singleton +static LayoutTypePtr get(); + +private: +LayoutType() {} +}; + +namespace detail { +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return ScalarTypeType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return LayoutType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return MemoryFormatType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +} // namespace detail + +// the common supertype of all lists, +// List[T] <: AnyList for all T +struct AnyListType; +using AnyListTypePtr = SingletonTypePtr; +struct TORCH_API AnyListType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "list"; + } + static const TypeKind Kind = TypeKind::AnyListType; + // global singleton + static AnyListTypePtr get(); +private: + AnyListType() + : Type(TypeKind::AnyListType) {} +}; + +// the common supertype of all tuples, +// Tuple[T...] 
<: AnyTuple for all T +struct AnyTupleType; +using AnyTupleTypePtr = SingletonTypePtr; +struct TORCH_API AnyTupleType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + + std::string str() const override { + return "tuple"; + } + static const TypeKind Kind = TypeKind::AnyTupleType; + + // global singleton + static AnyTupleTypePtr get(); +private: + AnyTupleType() + : Type(TypeKind::AnyTupleType) {} +}; + +// the common supertype of all classes, +// ClassType <: AnyClassType for all classes +struct AnyClassType; +using AnyClassTypePtr = SingletonTypePtr; +struct TORCH_API AnyClassType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "AnyClassType"; + } + static const TypeKind Kind = TypeKind::AnyClassType; + // global singleton + static AnyClassTypePtr get(); +private: + AnyClassType() + : Type(TypeKind::AnyClassType) {} +}; + +template<> +inline typename detail::CastReturnType::type Type::cast() { + if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || + kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; +} + +template<> +inline typename detail::CastConstReturnType::type Type::cast() const { + if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || + kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; +} + +template<> +inline const NamedType* Type::castRaw() const { + if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || + kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { + return static_cast(this); + } + return nullptr; +} + +// Used as a return type when inferring the IValue type of a Python object. +struct InferredType { + /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {} + /* implicit */ InferredType(std::string reason) + : type_(nullptr), reason_(std::move(reason)) {} + TypePtr type() const { + TORCH_INTERNAL_ASSERT( + type_, + "Tried to get the type from an InferredType but the type is null. 
", + "Reason: ", + reason_); + return type_; + } + bool success() const { + return type_ != nullptr; + } + const std::string& reason() const { + TORCH_INTERNAL_ASSERT(!type_); + return reason_; + } + +private: + TypePtr type_; + std::string reason_; +}; + +TORCH_API bool containsAnyType(const TypePtr& type); + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/jit_type_base.h b/phivenv/Lib/site-packages/torch/include/ATen/core/jit_type_base.h new file mode 100644 index 0000000000000000000000000000000000000000..0a64b3838098da6f352de9cdc4ca3bd9b0a07579 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/jit_type_base.h @@ -0,0 +1,721 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +#define C10_FORALL_TYPES(_) \ + _(AnyType) \ + _(EnumType) \ + _(AnyEnumType) \ + _(TensorType) \ + _(StorageType) \ + _(TupleType) \ + _(ListType) \ + _(DictType) \ + _(NumberType) \ + _(FloatType) \ + _(ComplexType) \ + _(FutureType) \ + _(AwaitType) \ + _(RRefType) \ + _(IntType) \ + _(NoneType) \ + _(StringType) \ + _(GeneratorType) \ + _(QuantizerType) \ + _(BoolType) \ + _(OptionalType) \ + _(VarType) \ + _(DeviceObjType) \ + _(StreamObjType) \ + _(FunctionType) \ + _(ClassType) \ + _(PyObjectType) \ + _(CapsuleType) \ + _(InterfaceType) \ + _(QSchemeType) \ + _(ScalarTypeType) \ + _(LayoutType) \ + _(MemoryFormatType) \ + _(AnyListType) \ + _(AnyTupleType) \ + _(AnyClassType) \ + _(SymIntType) \ + _(SymFloatType) \ + _(SymBoolType) \ + _(UnionType) \ + _(DynamicType) + +enum class TypeKind { +#define DEFINE_TYPE(T) T, + C10_FORALL_TYPES(DEFINE_TYPE) +#undef DEFINE_TYPE +}; + +TORCH_API const char* typeKindToString(TypeKind kind); + +struct Type; +struct SharedType; + +// Use this to customize how a Type is printed using `annotation_str()`. If +// std::nullopt is returned, `annotation_str()` falls through to its default +// implementation. 
+using TypePrinter = std::function(const Type&)>; + +namespace detail { +template +struct IsSingletonType : public std::integral_constant {}; +} // namespace detail +#define TORCH_DECLARE_SINGLETON(Type) \ + struct Type; \ + namespace detail { \ + template <> struct IsSingletonType : public std::integral_constant {}; \ + } + +TORCH_DECLARE_SINGLETON(AnyType) +TORCH_DECLARE_SINGLETON(AnyEnumType) +TORCH_DECLARE_SINGLETON(NumberType) +TORCH_DECLARE_SINGLETON(FloatType) +TORCH_DECLARE_SINGLETON(ComplexType) +TORCH_DECLARE_SINGLETON(IntType) +TORCH_DECLARE_SINGLETON(BoolType) +TORCH_DECLARE_SINGLETON(StringType) +TORCH_DECLARE_SINGLETON(StorageType) +TORCH_DECLARE_SINGLETON(NoneType) +TORCH_DECLARE_SINGLETON(GeneratorType) +TORCH_DECLARE_SINGLETON(QuantizerType) +TORCH_DECLARE_SINGLETON(QSchemeType) +TORCH_DECLARE_SINGLETON(DeviceObjType) +TORCH_DECLARE_SINGLETON(StreamObjType) +TORCH_DECLARE_SINGLETON(CapsuleType) +TORCH_DECLARE_SINGLETON(PyObjectType) +TORCH_DECLARE_SINGLETON(ScalarTypeType) +TORCH_DECLARE_SINGLETON(LayoutType) +TORCH_DECLARE_SINGLETON(MemoryFormatType) +TORCH_DECLARE_SINGLETON(AnyListType) +TORCH_DECLARE_SINGLETON(AnyTupleType) +TORCH_DECLARE_SINGLETON(AnyClassType) + +namespace detail { +template +struct CastReturnType { + using type = std::shared_ptr; +}; + +template +struct CastReturnType::value>> { + using type = SingletonTypePtr; +}; + +template +struct CastConstReturnType { + using type = std::shared_ptr; +}; + +template +struct CastConstReturnType::value>> { + using type = SingletonTypePtr; +}; + +template +struct as_shared_type { + using type = SharedType*; +}; + +template +struct as_shared_type { + using type = const SharedType *; +}; +} // namespace detail + +struct TORCH_API Type { + friend TORCH_API bool operator==(const Type& lhs, const Type& rhs); + private: + TypeKind kind_; + + protected: + Type(TypeKind kind) : kind_(kind) {} + + Type(const Type&) = default; + Type& operator=(const Type&) = default; + Type(Type&&) noexcept = default; + Type& operator=(Type&&) noexcept = default; + + virtual std::string annotation_str_impl(const TypePrinter& /*printer*/) const { + return str(); + } + // a == b + virtual bool equals(const Type& rhs) const = 0; + // a == b <=> b == a + virtual bool symmetric() const { + return true; + } + + public: + template + class SingletonOrSharedTypePtr { + public: + using element_type = typename std::shared_ptr::element_type; + + SingletonOrSharedTypePtr() = default; + + /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr x) + : repr_(std::move(x)) {} + + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr x) + : repr_(std::move(x)) {} + + /* implicit */ SingletonOrSharedTypePtr(std::nullptr_t) + : repr_(nullptr) {} + + /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p) + : repr_(p) {} + + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p) + : repr_(SingletonTypePtr(p.get())) {} + + + // We need to support construction from T* for pybind. The problem + // is that it's not clear if we are supposed to be taking shared + // ownership or not. + // + // Case 1: if T is known statically to derive from SharedType, we should use + // shared_from_this() and take shared_ownership. + // + // Case 2: if T is exactly Type, we need to do a dynamic_cast to + // check if it's a SharedType and do the right thing. + // + // Case 3: Otherwise, T is not a SharedType. (debug-check this + // assumption!) Use a singleton pointer. 
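+    // A minimal usage sketch of this pointer type from the caller's side
+    // (illustration only; IntType and ListType are declared in jit_type.h,
+    // not in this header, and `a`/`b` are hypothetical locals):
+    //
+    //   c10::TypePtr a = c10::IntType::get();        // borrowed singleton
+    //   c10::TypePtr b = c10::ListType::create(a);   // shared_ptr-owned
+    //   bool is_int = (a == c10::IntType::get());    // pointer-identity compare
+    //
+    // Either representation can be held in the same TypePtr; callers normally
+    // never need to know which case they hold.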
+ + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast::type>(p)->shared_from_this()) {} + + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(T* p) { + if (auto* shared_p = dynamic_cast::type>(p)) { + repr_ = Repr(shared_p->shared_from_this()); + } else { + repr_ = Repr(p); + } + } + + template && !std::is_base_of_v, bool> = true> + /* implicit */ SingletonOrSharedTypePtr(T* p) + : repr_(p) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast::type>(p) == nullptr); + } + + SingletonOrSharedTypePtr(const SingletonOrSharedTypePtr&) = default; + SingletonOrSharedTypePtr(SingletonOrSharedTypePtr&&) noexcept = default; + SingletonOrSharedTypePtr& operator=(const SingletonOrSharedTypePtr&) = default; + SingletonOrSharedTypePtr& operator=(SingletonOrSharedTypePtr&&) noexcept = default; + ~SingletonOrSharedTypePtr() = default; + + T* get() const { + return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast(repr_.rawRepr().first); + } + + operator bool() const { + return repr_.isNonNull(); + } + + bool operator==(std::nullptr_t) const { + return !repr_.isNonNull(); + } + + bool operator!=(std::nullptr_t) const { + return repr_.isNonNull(); + } + + template , void>, bool> = true> + U& operator*() const { + return *get(); + } + + T* operator->() const { + return get(); + } + + private: + // NOTE: SharedPtrWrapper exists to work around a baffling bug in + // nvcc; see comment in destroy() below. + struct SharedPtrWrapper { + SharedPtrWrapper(std::shared_ptr &&x) + : repr_(std::move(x)) {} + std::shared_ptr repr_; + }; + union Repr { + Repr() : Repr(nullptr) {} + + explicit Repr(std::shared_ptr x) + : shared_(std::move(x)) {} + + explicit Repr(std::nullptr_t) + : singletonRepr_(nullptr) {} + + explicit Repr(SingletonTypePtr p) + : singletonRepr_(p.get()) {} + + ~Repr() { + destroy(); + } + + // NOTE: the only non-UB way to access our null state is through + // rawRepr(), because our copy operation doesn't preserve which + // union member is active for null pointers. 
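+      // (Discrimination relies on std::shared_ptr occupying two pointer-sized
+      // words, element pointer then control block: a non-null shared
+      // representation leaves a non-null value in the second word, while the
+      // singleton representation explicitly stores nullptr in `unused_`, so
+      // isSharedAndNonNull() can tell the two apart from the raw words alone.)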
+ Repr(const Repr& rhs) { + if (rhs.isSharedAndNonNull()) { + new (&shared_) SharedPtrWrapper(rhs.shared_); + } else { + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + } + + Repr(Repr&& rhs) noexcept { + if (rhs.isSharedAndNonNull()) { + new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); + } else { + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + } + + Repr& operator=(const Repr& rhs) { + if (&rhs == this) { + return *this; + } + if (rhs.isSharedAndNonNull()) { + if (isSharedAndNonNull()) { + shared_ = rhs.shared_; + } else { + new (&shared_) SharedPtrWrapper(rhs.shared_); + } + } else { + if (isSharedAndNonNull()) { + destroy(); + } + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + return *this; + } + + Repr& operator=(Repr&& rhs) noexcept { + if (&rhs == this) { + return *this; + } + if (rhs.isSharedAndNonNull()) { + if (isSharedAndNonNull()) { + shared_ = std::move(rhs.shared_); + } else { + new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); + } + } else { + if (isSharedAndNonNull()) { + destroy(); + } + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + return *this; + } + + SharedPtrWrapper shared_; + + struct SingletonRepr { + explicit SingletonRepr(T* s) : singleton_(s) {} + T* singleton_; + void* unused_ = nullptr; + } singletonRepr_; + struct RawRepr { + void* first; + void* nullIfSingleton_; + }; + + // It is UB to read the singleton part of Repr if it was + // constructed as a shared_ptr and vice versa, but memcpying out + // the representation is always OK, so here's an accessor to obey + // the letter of the law. + RawRepr rawRepr() const { + RawRepr repr{}; + memcpy(&repr, reinterpret_cast(this), sizeof(RawRepr)); + return repr; + } + + bool isNonNull() const { + auto repr = rawRepr(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr); + return repr.first != nullptr; + } + + bool isSharedAndNonNull() const { + return rawRepr().nullIfSingleton_ != nullptr; + } + + private: + void destroy() { + if (isSharedAndNonNull()) { + // Without SharedPtrWrapper, this line would read + // `shared_.~shared_ptr()` and nvcc would complain with + // "error: expected primary-expression before '>' token" + // referring to the "t" in "shared_ptr". SharedPtrWrapper + // exists to work around this compiler bug. + shared_.~SharedPtrWrapper(); + } + } + } repr_; + }; + + using TypePtr = SingletonOrSharedTypePtr; + using Ptr = TypePtr; + using ElementType = Type; + + // subtyping relation. By default, we return true for the case + // when the type is exactly equal or if this <: T where rhs = Optional[T] + + // if this returns false and the why_not stream is non-null, it contains + // additional details that describe why this is not a subtype of 'rhs'. + // This additional information should only contain details that are not + // obvious from the annotation_str() that describes the type. For instance it + // is clear that `int <: str` is false but not clear why `Foo <: InterfaceBar` + // might be false. 
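+  // A caller-side sketch of the why_not protocol described above (`formal`
+  // and `actual` are hypothetical TypePtr values; <sstream> is assumed):
+  //
+  //   std::ostringstream why_not;
+  //   if (!actual->isSubtypeOfExt(*formal, &why_not)) {
+  //     TORCH_CHECK(false, "expected ", formal->repr_str(), " but got ",
+  //                 actual->repr_str(), "\n", why_not.str());
+  //   }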
+ virtual bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const; + virtual bool is_module() const; + bool isSubtypeOf(const Type& rhs) const { + return isSubtypeOfExt(rhs, nullptr); + } + // Compatibility shims to accommodate existing code that passes shared_ptrs + // around. Ideally, we would just delete this, but it should be harmless. + template + std::enable_if_t, bool> + isSubtypeOf(const std::shared_ptr& rhs) const { + return isSubtypeOf(*rhs); + } + + template + std::enable_if_t, bool> + isSubtypeOf(const SingletonOrSharedTypePtr& rhs) const { + return isSubtypeOf(*rhs); + } + + template + std::enable_if_t, bool> + isSubtypeOf(SingletonTypePtr rhs) const { + return isSubtypeOf(*rhs); + } + + template + std::enable_if_t, bool> + isSubtypeOfExt(const SingletonOrSharedTypePtr& rhs, std::ostream* why_not) const { + return isSubtypeOfExt(*rhs, why_not); + } + + template + std::enable_if_t, bool> + isSubtypeOfExt(const std::shared_ptr& rhs, std::ostream* why_not) const { + return isSubtypeOfExt(*rhs, why_not); + } + + template + std::enable_if_t, bool> + isSubtypeOfExt(SingletonTypePtr rhs, std::ostream* why_not) const { + return isSubtypeOfExt(*rhs, why_not); + } + + // How this type will appear in FunctionSchema declarations + virtual std::string str() const = 0; + + // How this type will appear as if it were a type annotation in Python + // which is sometimes different than how it appears in declarations (e.g. + // int[] vs List[int]) + // + // Takes a custom printer that users can pass in to customize the output of + // this method. + std::string annotation_str(const TypePrinter& printer) const { + if (printer) { + // the printer can return std::nullopt to fall through to the default impl + if (auto renamed = printer(*this)) { + return *renamed; + } + } + return annotation_str_impl(printer); + } + std::string annotation_str() const { + // Overload instead of define a default value for `printer` to help + // debuggers out. + return annotation_str(nullptr); + } + + // Returns a human readable string that includes additional information like + // "type is inferred rather than explicitly defined" to help construct more + // user-friendly messages. + virtual std::string repr_str() const { + return annotation_str(); + } + + TypeKind kind() const { + return kind_; + } + + virtual bool isUnionType() const { + return false; + } + + virtual bool requires_grad() const { + for (const auto& ct : containedTypes()) { + if (ct->requires_grad()) { + return true; + } + } + return false; + } + + // Dynamically cast this object to the subclass indicated by the + // template variable, returning nullptr if the cast is invalid. 
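+  // A short sketch of the casting helpers declared below (`t` is a
+  // hypothetical TypePtr; ListType, TupleType and ClassType are declared in
+  // jit_type.h):
+  //
+  //   if (auto list = t->cast<ListType>()) {       // nullptr on mismatch
+  //     auto elem = list->getElementType();
+  //   }
+  //   auto tup = t->expect<TupleType>();           // asserts on mismatch
+  //   const auto* cls = t->castRaw<ClassType>();   // raw pointer, may be null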
+ template ::value, bool> = true> + typename detail::CastReturnType::type cast() { + if (T::Kind == kind()) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; + } + template ::value, bool> = true> + typename detail::CastReturnType::type cast() { + if (T::Kind == kind()) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get()); + return typename detail::CastReturnType::type(static_cast(this)); + } + return nullptr; + } + template ::value, bool> = true> + typename detail::CastConstReturnType::type cast() const { + if (T::Kind == kind()) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; + } + template ::value, bool> = true> + typename detail::CastConstReturnType::type cast() const { + if (T::Kind == kind()) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get()); + return typename detail::CastConstReturnType::type(static_cast(this)); + } + return nullptr; + } + template + T* castRaw() { + if (T::Kind == kind()) { + return static_cast(this); + } + return nullptr; + } + template + const T* castRaw() const { + if (T::Kind == kind()) { + return static_cast(this); + } + return nullptr; + } + template + auto expect() { + auto r = cast(); + AT_ASSERT(r); + return r; + } + template + auto expect() const { + auto r = cast(); + AT_ASSERT(r); + return r; + } + template + T& expectRef() { + auto* r = castRaw(); + AT_ASSERT(r); + return *r; + } + template + const T& expectRef() const { + auto* r = castRaw(); + AT_ASSERT(r); + return *r; + } + virtual ~Type() = default; + virtual bool hasFreeVariables() const { + return false; + } + // list of types this type contains, e.g. for a List then element type of a + // list for a tuple, the types of the tuple elements + virtual at::ArrayRef containedTypes() const { + return {}; + } + virtual TypePtr containedType(size_t i) const { + return containedTypes().at(i); + } + virtual size_t containedTypeSize() const { + return containedTypes().size(); + } + // create a new version of this type, replacing its contained types with + // contained_types + TypePtr withContained(std::vector contained_types); + // per-type constructor, you only need to override this if the + // containedTypes() is not empty + virtual TypePtr createWithContained( + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector /*contained_types*/) const { + TORCH_CHECK(false, + "type with contained types did not overload createWithContained: ", + str()); + } + +}; + +template +using SingletonOrSharedTypePtr = Type::SingletonOrSharedTypePtr; + + +template +bool operator==(const SingletonOrSharedTypePtr& x, const SingletonOrSharedTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const SingletonOrSharedTypePtr& x, const std::shared_ptr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const std::shared_ptr& x, const SingletonOrSharedTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const SingletonOrSharedTypePtr& x, const SingletonTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const SingletonTypePtr& x, const SingletonOrSharedTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator!=(const SingletonOrSharedTypePtr& x, const SingletonOrSharedTypePtr& y) { + return !(x == y); +} + +template +bool operator!=(const SingletonOrSharedTypePtr& x, const std::shared_ptr& y) { + return !(x == y); +} + +template +bool 
operator!=(const std::shared_ptr& x, const SingletonOrSharedTypePtr& y) { + return !(x == y); +} + +template +bool operator!=(const SingletonOrSharedTypePtr& x, const SingletonTypePtr& y) { + return !(x == y); +} + +template +bool operator!=(const SingletonTypePtr& x, const SingletonOrSharedTypePtr& y) { + return !(x == y); +} + +using TypePtr = SingletonOrSharedTypePtr; +using ConstTypePtr = SingletonOrSharedTypePtr; + +// Explicitly enable MaybeOwned>, rather than allowing +// MaybeOwned to be used for any type right away. +template +struct MaybeOwnedTraits> + : public MaybeOwnedTraitsGenericImpl> {}; + +// Base class for Types that are guaranteed to be owned by std::shared_ptr. +struct TORCH_API SharedType : public Type, public std::enable_shared_from_this { + using Type::Type; +}; + +inline TypePtr Type::withContained(std::vector contained_types) { + auto current_contained = containedTypes(); + // Types with no contained_types don't need this call. Check before calling! + // + // (We can't support this efficiently because types without + // contained types may be singletons, in which case + // shared_from_this will crash; we would have to provide a virtual + // typeptr_from_this or isSingleton.) + TORCH_INTERNAL_ASSERT(!current_contained.empty() && current_contained.size() == contained_types.size()); + if (current_contained.equals(contained_types)) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return createWithContained(std::move(contained_types)); +} + + +TORCH_API inline bool operator==(const Type& lhs, const Type& rhs) { + if (C10_UNLIKELY(!rhs.symmetric())) { + return rhs.equals(lhs); + } + return lhs.equals(rhs); +} + +struct NamedType; +using NamedTypePtr = std::shared_ptr; +using ConstNamedTypePtr = std::shared_ptr; + +struct TORCH_API NamedType : public SharedType { + NamedType(TypeKind tk, std::optional name) + : SharedType(tk), name_(std::move(name)) { + TORCH_INTERNAL_ASSERT( + tk == TypeKind::TupleType || tk == TypeKind::FunctionType || + tk == TypeKind::ClassType || tk == TypeKind::InterfaceType || + tk == TypeKind::EnumType, + "If you add a new kind of NamedType, ", + "please update the cast specialization and this assert"); + } + + // Fully qualified name of type + // Looks like: "foo.bar.Baz". + const std::optional& name() const { + return name_; + } + + private: + std::optional name_; +}; + +} // namespace c10 + +namespace std { +template +struct hash> { + size_t operator()(const c10::SingletonOrSharedTypePtr& x) const { + return std::hash()(x.get()); + } +}; +} // namespace std diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/adaption.h b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/adaption.h new file mode 100644 index 0000000000000000000000000000000000000000..c009d5fd3fc68006830223923e4775a391603ecc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/adaption.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include +#include + +/* + * [Note: hacky wrapper removal for optional tensor] + * + * The kernel implementation takes an optional tensor marked in the schema as + * Tensor? but the C++ function takes Tensor instead of the std::optional + * expected by the dispatcher. 
+ * + * To remove the hacky wrapper, the C++ function is changed to take + * std::optional and unwrap the Tensor value at the beginning of + * the function, e.g.: + * > c10::MaybeOwned weight_maybe_owned = + * > at::borrow_from_optional_tensor(weight_opt); + * > const Tensor& weight = *weight_maybe_owned; + * + * We may want to make the kernel handle optional directly without + * going through the creation of a default-constructed Tensor in + * at::borrow_from_optional_tensor. + */ + +/* + * [Note: hacky wrapper removal for TensorOptions] + * + * The kernel implementation takes a TensorOptions argument but the dispatcher + * expects separate arguments for dtype, layout, device, pin_memory. + * + * To remove the hacky wrapper, the kernel implementation is changed to take + * the 4 arguments (dtype, layout, device, pin_memory), and assemble the + * TensorOptions value at the beginning of the function, e.g.: + * > TensorOptions options = TensorOptions().dtype(dtype).layout(layout) + * > .device(device).pinned_memory(pin_memory); + * + * We may want make the kernel handle these parameters directly without going + * through the creation of a TensorOptions value. + */ + +namespace c10::impl { + +TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName); + +inline void check_and_update_common_device(std::optional& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) { + // TODO: Remove this once the following issue is addressed: + // https://github.com/pytorch/pytorch/issues/57380 + if (!tensor.defined()) { + return; + } + + if (!common_device.has_value()) { + common_device = tensor.device(); + return; + } + + if (C10_UNLIKELY(common_device != tensor.device())) { + common_device_check_failure(*common_device, tensor, methodName, argName); + } +} + +inline void check_and_update_common_device(std::optional& common_device, const std::optional& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) { + if (tensor.has_value()) { + check_and_update_common_device(common_device, tensor.value(), methodName, argName); + } +} + +inline void check_and_update_common_device(std::optional& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { + for (const auto& tensor : tensors) { + check_and_update_common_device(common_device, tensor, methodName, argName); + } +} + +inline void check_and_update_common_device(std::optional& common_device, const List>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { + for (const auto& tensor : tensors) { + check_and_update_common_device(common_device, tensor, methodName, argName); + } +} +} // namespace c10::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/infer_schema.h b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/infer_schema.h new file mode 100644 index 0000000000000000000000000000000000000000..1e32706f2869f35943945482bd8afab72e312a8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/infer_schema.h @@ -0,0 +1,157 @@ +#pragma once + +/** + * This file contains functionality to take a C++ function and infer its + * c10::FunctionSchema. 
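+ *
+ * A brief sketch of the intended use (the kernel below is hypothetical):
+ *
+ * > at::Tensor my_add(const at::Tensor& a, const at::Tensor& b);
+ * > auto schema =
+ * >     c10::inferFunctionSchemaFlattenedReturns<decltype(my_add)>();
+ * > // schema now describes two Tensor arguments and a single Tensor return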
+ */ + +#include +#include + +namespace c10 { +namespace detail::infer_schema { + +/// The templated inference code creates `ArgumentDef` instead of `Argument`, +/// because that can be constructed at compile time and has a much smaller +/// binary size than having calls to `Argument` constructors in the template. +/// Creating `Argument` objects from `ArgumentDef` can then be done at +/// runtime in a non-templated way. +struct ArgumentDef final { + using GetTypeFn = TypePtr(); + GetTypeFn* getTypeFn; + GetTypeFn* getFakeTypeFn; + constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {} + explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {} +}; + +template +struct bool_t {}; +template<> struct bool_t : std::true_type {}; +template<> struct bool_t : std::false_type {}; + +/// Checks the static C++ types `Types` for correctness to catch common error cases. +template +constexpr int checkStaticTypes() { + // Give nice error messages for some of the common error cases. + // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT + static_assert(std::conjunction_v< + bool_t || std::is_same_v || std::is_same_v || std::is_same_v>... + >, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type"); + static_assert(std::conjunction_v< + bool_t>... + >, "INVALID TYPE: float is not supported as an argument type, use double instead"); + return 0; +} + +template +constexpr std::array createArgumentVectorFromTypes(std::index_sequence) { + return ( + // Check types for common errors + checkStaticTypes(), + + // Create the return value + std::array{ + ArgumentDef(&getTypePtrCopy>, &getFakeTypePtrCopy>)...} + ); +} + +/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified +/// as template arguments. +template struct createArguments final {}; +template +struct createArguments> final { + static constexpr std::array call() { + return createArgumentVectorFromTypes( + std::make_index_sequence() + ); + } +}; + +/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified +/// as a tuple (i.e. in the way c10 kernels return values). +/// It can be a tuple if there's three output arguments with types A, B, C. +/// It can be an empty tuple<>, or void for kernels that don't return anything. +/// It can be a single type A (i.e. no tuple) for the case where a kernel just +/// returns one value. +template struct createReturns final {}; + +template +struct createReturns, void> final { + static constexpr std::array call() { + return createArgumentVectorFromTypes( + std::make_index_sequence() + ); + } +}; + +template +struct createReturns && !guts::is_instantiation_of::value>> final { + static constexpr std::array call() { + return createReturns>::call(); + } +}; + +template<> +struct createReturns final { + static constexpr std::array call() { + return createReturns>::call(); + } +}; + +template +struct createSingleReturn { + static constexpr std::array call() { + return createArgumentVectorFromTypes(std::make_index_sequence<1>()); + } +}; + +TORCH_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef arguments, c10::ArrayRef returns); +TORCH_API FunctionSchema make_function_schema(c10::ArrayRef arguments, c10::ArrayRef returns); + +/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a +/// function. 
Flattens std::tuple returns into multiple return types +template +FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() { + using ReturnType = typename FunctionTraits::return_type; + using ParameterTypes = typename FunctionTraits::parameter_types; + + // arguments and returns are computed into a std::array at compile time and embedded into the binary. + // The only code executed at runtime here is the one that creates a std::vector + // of the arguments/returns from the std::array. + constexpr auto arguments = createArguments::call(); + constexpr auto returns = createReturns::call(); + + return make_function_schema(arguments, returns); +} + +/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a +/// function. Preserves std::tuple returns as a Tuple return type +template +FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) { + using ReturnType = typename FunctionTraits::return_type; + using ParameterTypes = typename FunctionTraits::parameter_types; + + // arguments and returns are computed into a std::array at compile time and embedded into the binary. + // The only code executed at runtime here is the one that creates a std::vector + // of the arguments/returns from the std::array. + constexpr auto arguments = createArguments::call(); + constexpr auto returns = createSingleReturn::call(); + + return make_function_schema(std::move(name), std::move(overload_name), arguments, returns); +} + +} + +template +FunctionSchema inferFunctionSchemaFlattenedReturns() { + return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns>(); +} + +template +FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) { + return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn>(std::move(name), std::move(overload_name)); +} + +TORCH_API std::optional findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h new file mode 100644 index 0000000000000000000000000000000000000000..a46cd9c1775aab7d6293cb218b019fe1c94a18a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h @@ -0,0 +1,181 @@ +#pragma once + +// TODO: unify to C10_MOBILE. In theory this header could be used in OSS. +#ifdef TEMPLATE_SELECTIVE_BUILD +#include +#endif + +/** + * This header implements functionality to build PyTorch with only a certain + * set of operators (+ dependencies) included. + * + * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these + * two ops will be included in your build. The allowlist records operators + * only, no overloads; if you include aten::add, all overloads of aten::add + * will be included. + * + * Internally, this is done by removing the operator registration calls + * using compile time programming, and the linker will then prune all + * operator functions that weren't registered. + * See Note [Selective build] for more details + * + * WARNING: The allowlist mechanism doesn't work for all ways you could go about + * registering an operator. If the dispatch key / operator name is not + * sufficiently obvious at compile time, then the allowlisting mechanism + * will fail (and the operator will be included in the binary anyway). 
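+ *
+ * Concretely (the flag value is shown only to illustrate the matching rule):
+ * with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub",
+ * > op_allowlist_check("aten::add")   // true  -> registration is kept
+ * > op_allowlist_check("aten::mul")   // false -> registration is compiled out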
+ */ + +#include +#include +#include + + +#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) +#include +#endif + +namespace c10::impl { + +constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item); // Forward Declare + +/** + * In selective build mode returns true/false depending on whether a build + * feature is available or not. + * + * In instrumenting mode (tracing mode), always returns true, and doesn't + * trigger any side effects. + */ +constexpr bool is_build_feature_available(const char* name) { +#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) + // Selective Build mode. +#if !defined(TORCH_BUILD_FEATURE_ALLOWLIST) + (void)name; + return true; +#else + return allowlist_contains( + C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST), + name); +#endif + +#else + // Instrumenting mode. + (void)name; + return true; +#endif +} + +[[noreturn]] void build_feature_required_feature_not_available(const char* feature); + +/** + * Use BUILD_FEATURE_REQUIRED macro in user-code. + * + * In selective build mode becomes a no-op if the build feature passed + * in is available. If not available, throws an exception (c10::Error). + * The compiler is able to perform dead code elimination for code + * following this method if the build feature is not available. + * + * In instrumenting mode (tracing mode), registers (as a side effect) + * the presence of this specific build feature being triggered. + */ +#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) // selective build mode + +#if defined(TORCH_BUILD_FEATURE_ALLOWLIST) +#define BUILD_FEATURE_REQUIRED(NAME) \ + if (!c10::impl::is_build_feature_available(NAME)) { \ + ::c10::impl::build_feature_required_feature_not_available(NAME); \ + } +#else // Everything trivially selected +#define BUILD_FEATURE_REQUIRED(NAME) + +#endif + +#else // trace mode +#define BUILD_FEATURE_REQUIRED(NAME) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::BUILD_FEATURE, \ + std::string(NAME), \ + {}); +#endif + +// Use this macro, and not is_build_feature_available +#define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME) + +// returns true iff allowlist contains item +// allowlist_contains("a;bc;d", "bc") == true +constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item) { + //Choose a really big value for next so that if something goes wrong + //this code will blow up in a hopefully detectable way. + size_t next = std::numeric_limits::max(); + for (size_t cur = 0; cur <= allowlist.size(); cur = next) { + next = allowlist.find(';', cur); + if (next != std::string_view::npos) { + if (allowlist.substr(cur, next - cur) == item) { + return true; + } + next++; + } else { + if (allowlist.substr(cur).compare(item) == 0) { + return true; + } + break; + } + } + return false; +} + +// Returns true iff the given op name is on the allowlist +// and should be registered +constexpr bool op_allowlist_check(std::string_view op_name [[maybe_unused]]) { + assert(op_name.find("::") != std::string_view::npos); + // Use assert() instead of throw() due to a gcc bug. 
See: + // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function + // https://github.com/fmtlib/fmt/issues/682 + assert(op_name.find('(') == std::string_view::npos); +#if !defined(TORCH_OPERATOR_WHITELIST) + // If the TORCH_OPERATOR_WHITELIST parameter is not defined, + // all ops are to be registered + return true; +#else + return allowlist_contains( + C10_STRINGIZE(TORCH_OPERATOR_WHITELIST), + // This function is majorly used for mobile selective build with + // root operators, where the overload is included in the allowlist. + op_name); + // // Strip overload name (as allowlist doesn't contain overloads) + // // Another function based on this may be added when there's usage + // // on op names without overload. + // OperatorNameView::parse(op_name).name); +#endif +} + +// Returns true iff the given schema string is on the allowlist +// and should be registered +constexpr bool schema_allowlist_check(std::string_view schema) { +#if defined(TORCH_FORCE_SCHEMA_REGISTRATION) + return true; +#else + return op_allowlist_check(schema.substr(0, schema.find('('))); +#endif +} + +// Returns true iff the given custom class name is on the allowlist +// and should be registered +constexpr bool custom_class_allowlist_check(std::string_view custom_class_name [[maybe_unused]]) { +#if !defined(TORCH_CUSTOM_CLASS_ALLOWLIST) + // If the TORCH_CUSTOM_CLASS_ALLOWLIST parameter is not defined, + // all custom classes are to be registered + return true; +#else + return allowlist_contains( + C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST), + custom_class_name); +#endif +} + +// schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST. +// Add this API to pass arbitrary allowlist. +constexpr bool op_allowlist_contains_name_in_schema(std::string_view allowlist, std::string_view schema) { + return allowlist_contains(allowlist, schema.substr(0, schema.find('('))); +} + +} // namespace c10::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/op_registration.h b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/op_registration.h new file mode 100644 index 0000000000000000000000000000000000000000..f92d9712c26cafa05046c2a025692bca5c151b78 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/op_registration/op_registration.h @@ -0,0 +1,596 @@ +#pragma once + +/** + * Include this file if you want to register operators. It includes all + * functionality needed to do so for you. + */ + +#include +#include +#include +#include +#include +#include +#include +#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD) +#include +#endif +#include + +namespace c10 { + +namespace detail { +// The first argument of the schema might be of type DispatchKeySet, in which case we remove it. +// We do this because every argument in a function schema is expected to be convertable +// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of. +// See Note [Plumbing Keys Through The Dispatcher] +template +std::unique_ptr inferFunctionSchemaFromFunctor() { + using func_type = typename c10::remove_DispatchKeySet_arg_from_func::func_type; + return std::make_unique(inferFunctionSchemaFlattenedReturns()); +} +} + +/** + * An instance of this class handles the registration for one or more operators. + * Make sure you keep the RegisterOperators instance around since it will + * deregister the operator it's responsible for in its destructor. 
+ * + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + */ +class TORCH_API RegisterOperators final { +public: + RegisterOperators() = default; + ~RegisterOperators() = default; + + RegisterOperators(const RegisterOperators&) = delete; + RegisterOperators& operator=(const RegisterOperators&) = delete; + RegisterOperators(RegisterOperators&&) noexcept = default; + RegisterOperators& operator=(RegisterOperators&&) noexcept = default; + + class TORCH_API Options final { + public: + Options(const Options&) = delete; + Options(Options&&) noexcept = delete; + Options& operator=(const Options&) = delete; + Options& operator=(Options&&) noexcept = delete; + + // internal-only for registering stack based kernels + template + Options&& kernel(DispatchKey dispatch_key) && { + return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction(), std::nullopt, nullptr); + } + + // internal-only for registering stack based catch-all kernels + template + Options&& catchAllKernel() && { + return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction(), std::nullopt, nullptr); + } + + // internal only for registering caffe2 ops + Options&& schema(FunctionSchema&& schema) { + TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration."); + schemaOrName_ = FunctionSchema(std::move(schema)); + return std::move(*this); + } + + /** + * Use this to specify the schema for an operator. You can also specify + * the operator name only to have the function signature part of the + * schema be inferred from the kernel function. + * + * Example: + * + * > // Infer function signature from my_kernel_cpu + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + * > + * > + * > // Explicitly specify full schema + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op(Tensor a) -> Tensor") + * > .kernel(DispatchKey::CPU)); + */ + Options&& schema(const std::string& schemaOrName) { + TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration."); + + #if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD) + throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build."); + #else + schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName); + #endif + + return std::move(*this); + } + + /** + * Use this to register an operator whose kernel is implemented as a functor. + * The kernel is only called for inputs matching the given dispatch key. + * You can register multiple kernels for different dispatch keys. 
+ * + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + * + * The functor constructor can take arguments to configure the kernel. + * The arguments are defined in the kernel registration. + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b) + * > : ... {...} + * > + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU, "some_configuration", 3, true)); + */ + template + // enable_if: only enable it if KernelFunctor is actually a functor + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && { + static_assert(std::is_base_of_v, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + static_assert(std::is_constructible_v, "Wrong argument list for constructor of kernel functor. The arguments to kernel(arguments...) must match one of the constructors of Functor."); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedFunctor(std::make_unique(std::forward(constructorParameters)...)), + impl::CppSignature::make(), + detail::inferFunctionSchemaFromFunctor() + ); + } + + /** + * Use this to register an operator whose kernel is implemented as a functor. + * The kernel is a catch-all kernel, meaning it's called independent from + * the input. Dispatch is disabled for this operator. + * + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel()); + * + * The functor constructor can take arguments to configure the kernel. + * The arguments are defined in the kernel registration. + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b) + * > : ... {...} + * > + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel("some_configuration", 3, true)); + */ + template + // enable_if: only enable it if KernelFunctor is actually a functor + std::enable_if_t::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && { + static_assert(std::is_base_of_v, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + static_assert(std::is_constructible_v, "Wrong argument list for constructor of kernel functor. The arguments to kernel(arguments...) 
must match one of the constructors of Functor."); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedFunctor(std::make_unique(std::forward(constructorParameters)...)), + impl::CppSignature::make(), + detail::inferFunctionSchemaFromFunctor() + ); + } + + /** + * Use this to register an operator whose kernel is implemented by a function. + * The kernel is only called for inputs matching the given dispatch key. + * You can register multiple kernels for different dispatch keys. + * + * Example: + * + * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + */ + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key) && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); + static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>::type>() + ); + } + + /** + * Use this to register an operator whose kernel is implemented by a function. + * The kernel is a catch-all kernel, meaning it's called independent from + * the input. Dispatch is disabled for this operator. + * + * Example: + * + * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel()); + */ + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> catchAllKernel() && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); + static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>::type>() + ); + } + + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) 
API or also implement the kernel function as defined by the public API."); + TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> catchAllKernel(FuncType* kernel_func) && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); + TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + /** + * Use this to register an operator whose kernel is implemented as a lambda. + * The kernel is only called for inputs matching the given dispatch key. + * You can register multiple kernels for different dispatch keys. + * + * The lambda must be stateless, i.e. not have a capture. If your kernel + * needs to store some configuration parameters, write the kernel as a + * functor instead. + * + * Example: + * + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...})); + */ + template + // enable_if: only enable it if Lambda is a functor (note: lambdas are functors) + std::enable_if_t< + guts::is_functor>::value + && !std::is_same_v>::func_type, KernelFunction::BoxedKernelFunction>, + Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && { + static_assert(!std::is_base_of_v>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel() API instead."); + + // We don't support stateful lambdas (i.e. lambdas with a capture), because their + // behavior would be nonobvious. A functor kernel with cache gets a new instance of + // its cache each time the kernel is looked up from the dispatch table. + // A lambda with a capture would be global and share its capture between all kernel lookups. + // So, instead of making users having to think about it (including the thread-safety + // issues this causes), let's just forbid stateful lambdas altogether. + static_assert(guts::is_stateless_lambda>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel() instead."); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedLambda(std::forward(functor)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + /** + * Use this to register an operator whose kernel is implemented as a lambda. + * The kernel is a catch-all kernel, meaning it's called independent from + * the input. Dispatch is disabled for this operator. 
+ * + * The lambda must be stateless, i.e. not have a capture. If your kernel + * needs to store some configuration parameters, write the kernel as a + * functor instead. + * + * Example: + * + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel([] (Tensor a) -> Tensor {...})); + */ + template + // enable_if: only enable it if Lambda is a functor (note: lambdas are functors) + std::enable_if_t< + guts::is_functor>::value + && !std::is_same_v>::func_type, KernelFunction::BoxedKernelFunction>, + Options&&> catchAllKernel(Lambda&& lambda) && { + static_assert(!std::is_base_of_v>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel() API instead."); + + // We don't support stateful lambdas (i.e. lambdas with a capture), because their + // behavior would be nonobvious. + // A lambda with a capture would be global and share its capture between all kernel lookups. + // This would be a likely source for unexpected race conditions, so we forbid it. + // If a kernel really needs global state, they can just have regular global state + // in their .cpp file next to the kernel lambda. + static_assert(guts::is_stateless_lambda>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel() instead."); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedLambda(std::forward(lambda)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && { + TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration."); + aliasAnalysisKind_ = aliasAnalysisKind; + return std::move(*this); + } + + private: + Options&& kernel(std::optional dispatch_key, KernelFunction&& func, std::optional cpp_signature, std::unique_ptr&& inferred_function_schema) && { + KernelRegistrationConfig config; + config.dispatch_key = dispatch_key; + config.func = std::move(func); + config.cpp_signature = cpp_signature; + config.inferred_function_schema = std::move(inferred_function_schema); + kernels.push_back(std::move(config)); + return std::move(*this); + } + + Options() + : schemaOrName_(std::nullopt) + , kernels() + , aliasAnalysisKind_(std::nullopt) + {} + + // KernelRegistrationConfig accumulates all information from the config + // parameters passed to a RegisterOperators::op() call into one object. + struct KernelRegistrationConfig final { + KernelRegistrationConfig() + : dispatch_key(std::nullopt) + , func() + , cpp_signature(std::nullopt) + , inferred_function_schema(nullptr) + {} + + std::optional dispatch_key; + KernelFunction func; + std::optional cpp_signature; + std::unique_ptr inferred_function_schema; + }; + + std::optional> schemaOrName_; + + std::vector kernels; + std::optional aliasAnalysisKind_; + friend class RegisterOperators; + friend class Library; + }; + + /** + * Call this to get an instance of registration options, which + * can be passed to a call to RegisterOperators::op() to specify + * these options for the operator registration. + * See class doc comment for examples. + */ + static Options options() { + return {}; + } + + /** + * Call this to register an operator. 
See class doc comment for examples. + */ + RegisterOperators&& op(Options&& options) && { + checkSchemaAndRegisterOp_(std::move(options)); + return std::move(*this); + } + + // Regular mutator version of the && version above + RegisterOperators& op(Options&& options) & { + checkSchemaAndRegisterOp_(std::move(options)); + return *this; + } + + /** + * This is a shorthand for RegisterOperators::op(Options) where you can + * specify the operator schema outside of the options parameter. + * See class doc comment for examples. + */ + RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && { + return std::move(*this).op(std::move(options).schema(schemaOrName)); + } + + // internal only for registering caffe2 ops + RegisterOperators&& op(FunctionSchema schema, Options&& options) && { + return std::move(*this).op(std::move(options).schema(std::move(schema))); + } + + template + explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options()) + : RegisterOperators() { + std::move(*this).op(schemaOrName, std::forward(func), std::move(options)); + } + + /** + * This API registers an operator based on a kernel function pointer. + * + * Given a kernel + * + * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } + * + * This API looks like: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", &my_kernel_cpu); + * + * If your kernel is small and the overhead of calling it matters, + * then this API might be the wrong choice since the following API + * has a slightly lower overhead for calling into the kernel: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", c10::RegisterOperators::options() + * > .kernel()); + * + * Or, alternatively, write your kernel as a functor: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", c10::RegisterOperators::options() + * > .kernel()); + */ + template + // enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction. + std::enable_if_t::value && !std::is_same_v, RegisterOperators&&> + op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && { + constexpr bool AllowLegacyTypes = true; + return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedRuntimeFunction(func), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + )); + } + + /** + * This API registers an operator based on a kernel lambda. 
+ * + * This API looks like: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", [] (Tensor a, Tensor b) {...}); + * + * This is equivalent to: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", c10::RegisterOperators::options() + * > .catchAllKernel([] (Tensor a, Tensor b) {...})); + * + */ + template + // enable_if: only enable it if Lambda is actually a stateless lambda + std::enable_if_t::value && guts::is_stateless_lambda>::value, RegisterOperators&&> + op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && { + static_assert(!std::is_base_of_v, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead."); + + constexpr bool AllowLegacyTypes = true; + return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedLambda(std::forward(lambda)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + )); + } + + template + C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead.") + // enable_if: only enable it if Lambda is actually a functor but not a stateless lambda + std::enable_if_t::value && !guts::is_stateless_lambda>::value, RegisterOperators&&> + op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && { + static_assert(!std::is_base_of_v, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. 
Please use the new RegisterOperators::options().kernel() based API instead."); + + constexpr bool AllowLegacyTypes = true; + return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedLambda(std::forward(lambda)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + )); + } + +private: + void checkSchemaAndRegisterOp_(Options&& config); + + static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options); + void checkNoDuplicateKernels_(const Options& options); + void registerOp_(Options&& options); + + std::vector registrars_; +}; + +} // namespace c10 + +namespace torch { + // Old-style API + using RegisterOperators = c10::RegisterOperators; +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/operator_name.h b/phivenv/Lib/site-packages/torch/include/ATen/core/operator_name.h new file mode 100644 index 0000000000000000000000000000000000000000..63d85c4c6dba3d6b29a8880df39b173b6e4b3cb8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/operator_name.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// TODO: consider storing namespace separately too +struct OperatorName final { + std::string name; + std::string overload_name; + OperatorName(std::string name, std::string overload_name) + : name(std::move(name)), overload_name(std::move(overload_name)) {} + + // TODO: These two functions below are slow! Fix internal data structures so + // I don't have to manually reconstruct the namespaces! + + // Return the namespace of this OperatorName, if it exists. The + // returned string_view is only live as long as the OperatorName + // exists and name is not mutated + std::optional getNamespace() const { + auto pos = name.find("::"); + if (pos == std::string::npos) { + return std::nullopt; + } else { + return std::string_view(name.data(), pos); + } + } + + // Returns true if we successfully set the namespace + bool setNamespaceIfNotSet(const char* ns) { + if (!getNamespace().has_value()) { + const auto ns_len = strlen(ns); + const auto old_name_size = name.size(); + name.resize(ns_len + 2 + old_name_size); + // Shift current value of name to the end of the new space. + name.replace( + name.size() - old_name_size, old_name_size, name, 0, old_name_size); + name.replace(0, ns_len, ns, ns_len); + name[ns_len] = ':'; + name[ns_len + 1] = ':'; + return true; + } else { + return false; + } + } +}; + +// Non-owning view of an OperatorName. 
Unlike OperatorName, most of +// its functions are constexpr, so it can be used for compile time +// computations +struct OperatorNameView final { + std::string_view name; + std::string_view overload_name; + constexpr OperatorNameView( + std::string_view name, + std::string_view overload_name) + : name(name), overload_name(overload_name) {} + // Parses strings like "foo.overload" and also "foo" + constexpr static OperatorNameView parse(std::string_view full_name) { + auto i = full_name.find('.'); + if (i == std::string_view::npos) { + return OperatorNameView(full_name, std::string_view()); + } else { + return OperatorNameView(full_name.substr(0, i), full_name.substr(i + 1)); + } + } +}; + +inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) { + return lhs.name == rhs.name && lhs.overload_name == rhs.overload_name; +} + +inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) { + return !operator==(lhs, rhs); +} + +TORCH_API std::string toString(const OperatorName& opName); +TORCH_API std::ostream& operator<<(std::ostream&, const OperatorName&); + +} // namespace c10 + +namespace std { +template <> +struct hash<::c10::OperatorName> { + size_t operator()(const ::c10::OperatorName& x) const { + return std::hash()(x.name) ^ + (~std::hash()(x.overload_name)); + } +}; +} // namespace std diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/qualified_name.h b/phivenv/Lib/site-packages/torch/include/ATen/core/qualified_name.h new file mode 100644 index 0000000000000000000000000000000000000000..fcc5bdada9b276b2aa7745d87d47c850c95c6894 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/qualified_name.h @@ -0,0 +1,161 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +// Represents a name of the form "foo.bar.baz" +struct QualifiedName { + QualifiedName() = default; + + // `name` can be a dotted string, like "foo.bar.baz", or just a bare name. + /* implicit */ QualifiedName(const std::string& name) { + TORCH_CHECK(!name.empty()); + // split the string into its atoms. + size_t startSearchFrom = 0; + size_t pos = name.find(delimiter_, startSearchFrom); + + while (pos != std::string::npos) { + auto atom = name.substr(startSearchFrom, pos - startSearchFrom); + TORCH_INTERNAL_ASSERT( + !atom.empty(), "Invalid name for qualified name: '", name, "'"); + atoms_.push_back(std::move(atom)); + startSearchFrom = pos + 1; + pos = name.find(delimiter_, startSearchFrom); + } + + auto finalAtom = name.substr(startSearchFrom); + TORCH_INTERNAL_ASSERT( + !finalAtom.empty(), "Invalid name for qualified name: '", name, "'"); + atoms_.emplace_back(std::move(finalAtom)); + + cacheAccessors(); + } + + explicit QualifiedName(std::vector atoms) : atoms_(std::move(atoms)) { + for (const auto& atom : atoms_) { + TORCH_CHECK(!atom.empty(), "Atom cannot be empty"); + TORCH_CHECK( + atom.find(delimiter_) == std::string::npos, + "Delimiter not allowed in atom"); + } + + cacheAccessors(); + } + // Unnecessary copy. Ideally we'd use something like std::string_view. + /* implicit */ QualifiedName(const char* name) + : QualifiedName(std::string(name)) {} + + // `name` must be a bare name (no dots!) 
+ explicit QualifiedName(const QualifiedName& prefix, std::string name) { + TORCH_INTERNAL_ASSERT(!name.empty()); + TORCH_INTERNAL_ASSERT(name.find(delimiter_) == std::string::npos); + atoms_.insert(atoms_.begin(), prefix.atoms_.begin(), prefix.atoms_.end()); + atoms_.push_back(std::move(name)); + + cacheAccessors(); + } + + // Is `this` a prefix of `other`? + // For example, "foo.bar" is a prefix of "foo.bar.baz" + bool isPrefixOf(const QualifiedName& other) const { + const auto& thisAtoms = atoms_; + const auto& otherAtoms = other.atoms_; + + if (thisAtoms.size() > otherAtoms.size()) { + // Can't be a prefix if it's bigger + return false; + } + for (const auto i : c10::irange(thisAtoms.size())) { + if (thisAtoms[i] != otherAtoms[i]) { + return false; + } + } + return true; + } + + // The fully qualified name, like "foo.bar.baz" + const std::string& qualifiedName() const { + return qualifiedName_; + } + + // The leading qualifier, like "foo.bar" + const std::string& prefix() const { + return prefix_; + } + + // The base name, like "baz" + const std::string& name() const { + return name_; + } + + const std::vector& atoms() const { + return atoms_; + } + + bool operator==(const QualifiedName& other) const { + return this->qualifiedName_ == other.qualifiedName_; + } + + bool operator!=(const QualifiedName& other) const { + return !(*this == other); + } + + private: + static constexpr char delimiter_ = '.'; + + // Helper for cacheAccessors() below. + template + std::string join(char delimiter, const T& v) { + std::string out; + size_t reserve = 0; + for (const auto& e : v) { + reserve += e.size() + 1; + } + out.reserve(reserve); + for (const auto i : c10::irange(v.size())) { + if (i != 0) { + out.push_back(delimiter); + } + out.append(v[i]); + } + return out; + } + + void cacheAccessors() { + qualifiedName_ = join(delimiter_, atoms_); + if (atoms_.size() > 1) { + ArrayRef view(atoms_); + const auto prefixView = view.slice(0, view.size() - 1); + prefix_ = join(delimiter_, prefixView); + } + + if (!atoms_.empty()) { + name_ = atoms_.back(); + } + } + + // The actual list of names, like "{foo, bar, baz}" + std::vector atoms_; + + /* + * Cached accessors, derived from `atoms_`. + */ + std::string qualifiedName_; + std::string prefix_; + std::string name_; +}; +} // namespace c10 + +namespace std { +template <> +struct hash { + size_t operator()(const c10::QualifiedName& n) const noexcept { + return std::hash()(n.qualifiedName()); + } +}; +} // namespace std diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/rref_interface.h b/phivenv/Lib/site-packages/torch/include/ATen/core/rref_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..563e9c0d102d239bf9c1e210159df4db37ac76a9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/rref_interface.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +namespace c10 { + +struct Type; +using worker_id_t = int16_t; + +// This abstract class contains only user-facing APIs, and will be shared +// between jit and distributed to implement TorchScript support. +class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target { + public: + RRefInterface() = default; + // RRef is made NOT copyable NOT movable to prevent messing up reference + // counting. 
+ RRefInterface(const RRefInterface& other) = delete; + RRefInterface(RRefInterface&& other) = delete; + RRefInterface& operator=(const RRefInterface& other) = delete; + RRefInterface& operator=(RRefInterface&& other) = delete; + + ~RRefInterface() override = default; + + // returns the worker id of the owner + virtual worker_id_t owner() const = 0; + + // returns the worker name of the owner + virtual std::string ownerName() const = 0; + + // Returns true if this is the ``OwnerRRef`` + virtual bool isOwner() const = 0; + + // Returns true if this is an ``OwnerRRef`` or if this ``UserRRef`` has been + // confirmed by its owner. + virtual bool confirmedByOwner() const = 0; + + virtual const TypePtr type() const = 0; +}; + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/stack.h b/phivenv/Lib/site-packages/torch/include/ATen/core/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..a63fa72c0f1697badc9f3ce20552d6b267b6d575 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/stack.h @@ -0,0 +1,204 @@ +#pragma once + +#include + +#include +#include +#include + +// TODO move this to c10 namespace + + +namespace torch::jit { + +using c10::IValue; +using Stack = std::vector; + +class Operation { + template + using accepts = std::is_constructible, F&&>; + + public: + template ::value, int> = 0> + C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.") + Operation(F&& raw): op_([raw = std::forward(raw)](Stack& stack) { + raw(&stack); + }) {} + + template ::value && + !std::is_same_v, Operation>, int> = 0> + Operation(F&& op): op_(std::forward(op)) {} + + Operation(std::nullptr_t) noexcept {} + + explicit operator bool() const noexcept { + return op_ ? true : false; + } + + void operator()(Stack& stack) { + op_(stack); + } + + template + T* target() noexcept { + return op_.target(); + } + + private: + std::function op_; +}; + +// An operation with N inputs and M outputs pops the last N inputs off +// the stack and pushes its M inputs onto the stack +// before: I0, I1, ... IN <- stack.back() +// after: O0, O1, ... OM +// operations are defined this way so that ownership of inputs can be +// transferred to the operation and it can incrementally drop ownership of +// tensors when they become unneeded. 
For large operations, like 'run an entire +// subgraph', this functionality is very important for minimizing gpu memory +// usage return value is the relative 'offset' to jump to for the next +// operation: +// pc += 1 + offset +// so a return value of 0 goes to the next instruction + +// treat the last N elements of the stack as a list, looking up +// element i +inline IValue& peek(Stack& stack, size_t i, size_t N) { + // NOLINTNEXTLINE(*-narrowing-conversions) + return *(stack.end() - N + i); +} +inline IValue& peek(Stack* stack, size_t i, size_t N) { + return peek(*stack, i, N); +} +inline const IValue& peek(const Stack& stack, size_t i, size_t N) { + // NOLINTNEXTLINE(*-narrowing-conversions) + return *(stack.end() - N + i); +} +inline const IValue& peek(const Stack* stack, size_t i, size_t N) { + return peek(*stack, i, N); +} +// treat the last N elements of the stack as a list, looking up the +// slice starting at index i and having length len +inline at::ArrayRef peekSlice( + const Stack& stack, + size_t i, + size_t len, + size_t N) { + return at::ArrayRef(stack).slice(stack.size() - N + i, len); +} +inline at::ArrayRef last(const Stack& stack, size_t N) { + return peekSlice(stack, 0, N, N); +} +inline at::ArrayRef last(const Stack* stack, size_t N) { + return last(*stack, N); +} +inline void drop(Stack& stack, size_t n) { + // NOLINTNEXTLINE(*-narrowing-conversions) + stack.erase(stack.end() - n, stack.end()); +} +inline void drop(Stack* stack, size_t n) { + drop(*stack, n); +} +inline IValue pop(Stack& stack) { + TORCH_CHECK(!stack.empty(), "pop() called on empty stack"); + auto r = std::move(stack.back()); + stack.pop_back(); + return r; +} +inline IValue pop(Stack* stack) { + return pop(*stack); +} +inline std::vector pop(Stack& stack, size_t n) { + std::vector result; + result.reserve(n); + for (const auto i : c10::irange(n)) { + result.push_back(std::move(peek(stack, i, n))); + } + drop(stack, n); + return result; +} + +// variadic pop: +// int64_t a; at::Tensor b; +// pop(stack, a, b); +// equivalent to: +// b = pop(stack).toTensor(); +// a = pop(stack).toInt(); +template +inline void pop(Stack& stack, Types&... args) { + size_t i = 0; + constexpr size_t N = sizeof...(args); + (void)std::initializer_list{ + (args = std::move(peek(stack, i++, N)).template to(), 0)...}; + drop(stack, N); +} +template +inline void pop(Stack* stack, Types&... args) { + pop(*stack, args...); +} +template +inline void push_one(Stack& stack, Type&& arg) { + stack.emplace_back(std::forward(arg)); +} + +inline void push_one(Stack& stack, c10::TensorOptions options) { + stack.emplace_back(c10::typeMetaToScalarType(options.dtype())); + stack.emplace_back(options.layout()); + stack.emplace_back(options.device()); + stack.emplace_back(options.pinned_memory()); +} + +template +inline void push(Stack& stack, Types&&... args) { + (void)std::initializer_list{(push_one(stack, std::forward(args)), 0)...}; +} +template +inline void push(Stack* stack, Types&&... args) { + return push(*stack, std::forward(args)...); +} +template +inline void push_list_elements(Stack& stack, const c10::List& elements) { + for (T elem : elements) { + stack.push_back(std::move(elem)); + } +} + +// The packer here is carefully written not to make any unnecessary +// copies. 
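+// Illustrative sketch (not part of the original header): one way the stack
+// helpers above are typically combined inside an Operation body. The operator
+// name below is hypothetical; only push/pop/peek/drop come from this file.
+//
+//   void my_add_op(Stack& stack) {
+//     at::Tensor a, b;
+//     pop(stack, a, b);      // pops the two inputs (a was pushed first)
+//     push(stack, a + b);    // pushes the single output for the next op
+//   }
+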
+ +// pack takes the return values of aten functions pushes them onto the stack +template +inline void pack(Stack& stack, T&& v) { + stack.emplace_back(std::forward(v)); +} +template +inline void pack(Stack* stack, T&& v) { + pack(*stack, std::forward(v)); +} + +template +struct TuplePacker { + // NB: *Not* a universal reference. + static void execute(Stack& stack, std::tuple&& t) { + // NB: The move here does not "destroy" the entire tuple, that is + // not what std::move does; only the particular tuple index + // processed here gets stolen. + pack(stack, std::get(std::move(t))); + TuplePacker::execute(stack, std::move(t)); + } +}; + +template +struct TuplePacker<0, Args...> { + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + static void execute(Stack& /*stack*/, std::tuple&& /*t*/){} +}; + +template +inline void pack(Stack& stack, std::tuple&& t) { + TuplePacker::execute(stack, std::move(t)); +} + +} // namespace torch::jit diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/symbol.h b/phivenv/Lib/site-packages/torch/include/ATen/core/symbol.h new file mode 100644 index 0000000000000000000000000000000000000000..a96fbe55d0a54028d4cdbaa020f412566407c0fd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/symbol.h @@ -0,0 +1,147 @@ +#pragma once +#include +#include +#include // For std::hash +#include + + +namespace c10 { + +// 'prim' symbols are synthetic operators that occur only in the IR +// and don't have corresponding implementations in ATen. + +// 'onnx' symbols correspond to ONNX operators. Their semantics +// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md +// The particular version we are targeting is specified by '_onnx_opset_version' +// in torch.onnx.symbolic_helper +// +// In general, most ONNX operators won't get an entry here, because they +// are handled from the Python end. However, you may occasionally need +// to intern an ONNX symbol here so that you can conveniently write an +// optimization on ONNX operations. + +// 'attr' symbols are attribute keys. They are shared between both ONNX and ATen +// operators (you disambiguate their meaning by looking at the operator itself). +// In general, you only need to define attribute keys that are used by +// onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS. + +// Note [Symbol allocation] +// ~~~~~~~~~~~~~~~~~~~~~~~~ +// +// 1. Symbol namespace is split up into namespaces. +// +// 2. The intended access pattern for built-in symbols is onnx::MatMul +// in the c10 namespace (this is a Symbol). +// + +// Built-in constant definition strategy: +// - Enum is the most convenient way to generate a contiguous sequence +// of numbers for an identifier. +// - However, an enum gives you a fresh type. We want onnx::MatMul to +// be type Symbol, not some random enum type! +// - Therefore, after using enums to generate the sequence of integers, +// we then declare constexpr Symbols to get everything the actual Symbol +// type we want. Symbols must be constexpr to be valid to be "case"ed on. + +using unique_t = uint32_t; + +const std::string& domain_prefix(); + +// A Symbol is like an interned string, but with a little extra +// structure; it is namespaced via SymbolNamespace and the resulting +// intern pointers support efficient namespace testing. 
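+// Hedged usage sketch (illustrative only, not part of the original header):
+// a Symbol is usually obtained from a qualified string and then queried for
+// its namespace and unqualified name, e.g.
+//
+//   c10::Symbol s = c10::Symbol::fromQualString("aten::mm");
+//   bool in_aten = s.is_aten();              // true
+//   const char* unqual = s.toUnqualString(); // "mm"
+//   const char* qual = s.toQualString();     // "aten::mm"
+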
+struct TORCH_API Symbol { + explicit constexpr Symbol() : value(0) {} + explicit constexpr Symbol(unique_t uniq) + : value(uniq) {} + + // Get a Symbol for a qualified string like "attr::bar" + static Symbol fromQualString(const std::string & s); + + // Get a Symbol from a domain and an unqualified string like "org.pytorch.attr" and "bar" + static Symbol fromDomainAndUnqualString(const std::string & d, const std::string & s); + + // Constructors for our various namespaced strings. This will construct + // the appropriate namespaced string, e.g., "attr::foo" for the + // argument "foo", and then attempt to intern it. DO NOT USE THIS + // with a string literal; attr::foo should be available in that case + // (and if it's not, you should add it to the built-ins list above.) + static Symbol attr(const std::string & s); + static Symbol aten(const std::string & s); + static Symbol cuda(const std::string & s); + static Symbol onnx(const std::string & s); + static Symbol prim(const std::string & s); + static Symbol user(const std::string & s); + static Symbol caffe2(const std::string & s); + static Symbol dimname(const std::string & s); + // TODO: eliminate me + static Symbol scope(const std::string & s); + + bool is_attr() const; + bool is_aten() const; + bool is_cuda() const; + bool is_prim() const; + bool is_prims() const; + bool is_nvprims() const; + bool is_onnx() const; + bool is_user() const; + bool is_caffe2() const; + bool is_dimname() const; + + // So we can switch on this + constexpr operator unique_t() const { + return value; + } + + Symbol ns() const; + + // Give a string corresponding to the unqualified version of this name, e.g., + // "mm". Use this in a context where the intended namespace of the string is + // obvious; this is a *lossy* conversion. + const char * toUnqualString() const; + + // Give a string corresponding to the qualified version of this name, + // e.g., "aten::mm". This string format is made available to Python bindings + // (so we know how to parse it.) + const char * toQualString() const; + + // This describes a symbol in a case where humans read it. At the moment it's + // the same as toQualString. This has to be a const char* returned because + // a lot of printf style macros use it. + const char * toDisplayString() const; + + // Give a string corresponding to the domain name for the symbol, + // e.g., "org.pytorch.aten". 
+ std::string domainString() const; + +private: + + explicit Symbol(Symbol ns, const std::string & s); + unique_t value; +}; + +static inline bool operator==(Symbol lhs, Symbol rhs) { + return static_cast(lhs) == static_cast(rhs); +} + +inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); } +inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); } +inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); } +inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); } +inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); } +inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); } +inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); } +inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); } +inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); } + +} // namespace c10 + +// make symbol behave like an integer in hash tables +namespace std { +template <> +struct hash { + size_t operator()(c10::Symbol s) const { + return std::hash()(static_cast(s)); + } +}; +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/type_factory.h b/phivenv/Lib/site-packages/torch/include/ATen/core/type_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..096069684d1c98efb55e939c224c352846df9e95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/type_factory.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +template +struct TORCH_API TypeFactoryBase {}; + +template <> +struct TORCH_API TypeFactoryBase { + template + static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) { + return std::make_shared( + c10::DynamicTypeTrait::tagValue(), + c10::DynamicType::Arguments(c10::ArrayRef( + {std::move(ty), std::forward(args)...}))); + } + template + static c10::DynamicTypePtr create(const std::vector& types) { + return std::make_shared( + c10::DynamicTypeTrait::tagValue(), + c10::DynamicType::Arguments(types)); + } + static c10::DynamicTypePtr createNamedTuple( + const std::string& name, + const std::vector& fields, + const std::vector& types) { + return std::make_shared( + c10::DynamicType::Tag::Tuple, + name, + c10::DynamicType::Arguments(fields, types)); + } + template + C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) { + return std::make_shared( + c10::DynamicTypeTrait::tagValue(), + name, + c10::DynamicType::Arguments{}); + } + template + C10_ERASE static decltype(auto) get() { + return DynamicTypeTrait::getBaseType(); + } + static const std::unordered_map& basePythonTypes(); +}; + +using DynamicTypeFactory = TypeFactoryBase; + +// Helper functions for constructing DynamicTypes inline. +template < + typename T, + std::enable_if_t::isBaseType, int> = 0> +C10_ERASE DynamicTypePtr dynT() { + return DynamicTypeFactory::get(); +} + +template < + typename T, + typename... Args, + std::enable_if_t::isBaseType, int> = 0> +C10_ERASE DynamicTypePtr dynT(Args&&... args) { + return DynamicTypeFactory::create(std::forward(args)...); +} + +template <> +struct TORCH_API TypeFactoryBase { + template + static c10::TypePtr create(TypePtr ty, Args&&... 
args) { + return T::create(std::move(ty), std::forward(args)...); + } + template + static c10::TypePtr create(std::vector types) { + return T::create(std::move(types)); + } + static c10::TypePtr createNamedTuple( + const std::string& name, + const std::vector& fields, + const std::vector& types); + template + C10_ERASE static c10::TypePtr createNamed(const std::string& name) { + return T::create(name); + } + static const std::unordered_map& basePythonTypes(); + template + C10_ERASE static c10::TypePtr get() { + return T::get(); + } +}; + +using DefaultTypeFactory = TypeFactoryBase; + +using PlatformType = +#ifdef C10_MOBILE + c10::DynamicType +#else + c10::Type +#endif + ; + +using TypeFactory = TypeFactoryBase; + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/type_ptr.h b/phivenv/Lib/site-packages/torch/include/ATen/core/type_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..cf8796e278e034cbf39c20f8e7560903189c4d36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/type_ptr.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { + +// Compatibility wrapper around a raw pointer so that existing code +// written to deal with a shared_ptr can keep working. +template +class SingletonTypePtr { + public: + /* implicit */ SingletonTypePtr(T* p) : repr_(p) {} + + // We need this to satisfy Pybind11, but it shouldn't be hit. + explicit SingletonTypePtr(std::shared_ptr) { TORCH_CHECK(false); } + + using element_type = typename std::shared_ptr::element_type; + + template , void>, bool> = true> + T& operator*() const { + return *repr_; + } + + T* get() const { + return repr_; + } + + T* operator->() const { + return repr_; + } + + operator bool() const { + return repr_ != nullptr; + } + + private: + T* repr_{nullptr}; +}; + +template +bool operator==(SingletonTypePtr lhs, SingletonTypePtr rhs) { + return (void*)lhs.get() == (void*)rhs.get(); +} + +template +bool operator!=(SingletonTypePtr lhs, SingletonTypePtr rhs) { + return !(lhs == rhs); +} + +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/typeid.h b/phivenv/Lib/site-packages/torch/include/ATen/core/typeid.h new file mode 100644 index 0000000000000000000000000000000000000000..d69eba920abb0059a113405faf0264cd5a9b7bab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/typeid.h @@ -0,0 +1 @@ +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/FlushDenormal.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/FlushDenormal.h new file mode 100644 index 0000000000000000000000000000000000000000..0d7b4b9cc679c93d48f3b1be053a7ff9fb004128 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/FlushDenormal.h @@ -0,0 +1,14 @@ +/// Flush-To-Zero and Denormals-Are-Zero mode +/// +/// Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) are modes that bypass +/// IEEE 754 methods of dealing with denormal floating-point numbers on x86-64 +/// and some x86 CPUs. They result in reduced precision for values near zero, +/// but increased performance. 
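+///
+/// A minimal usage sketch (illustrative; it assumes only the declaration
+/// below, and that the return value reports whether the mode could be set):
+///
+///   if (at::cpu::set_flush_denormal(true)) {
+///     // denormal inputs/outputs are now treated as zero
+///   }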
+/// +/// See https://software.intel.com/en-us/articles/x87-and-sse-floating-point-assists-in-ia-32-flush-to-zero-ftz-and-denormals-are-zero-daz + +namespace at::cpu { + +bool set_flush_denormal(bool on); + +} // namespace at::cpu diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/Utils.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..90f730c8711516d01cab15cc0d6714c598d2f08b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/Utils.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include + +namespace at::cpu { + +TORCH_API bool is_avx2_supported(); +TORCH_API bool is_avx512_supported(); + +// Detect if CPU support Vector Neural Network Instruction. +TORCH_API bool is_avx512_vnni_supported(); + +// Detect if CPU supports AVX512_BF16 ISA +TORCH_API bool is_avx512_bf16_supported(); + +// Detect if CPU support Advanced Matrix Extension. +TORCH_API bool is_amx_tile_supported(); + +// Detect if CPU support Advanced Matrix Extension for fp16. +TORCH_API bool is_amx_fp16_supported(); + +// Enable the system to use AMX instructions. +TORCH_API bool init_amx(); + +// Get the L1 cache size per core in Byte +TORCH_API uint32_t L1d_cache_size(); + +// Get the L2 cache size per core in Byte +TORCH_API uint32_t L2_cache_size(); + +} // namespace at::cpu diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..032e9bfa471391b3a38e56dedd04c7a881a241f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional_base.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional_base.h new file mode 100644 index 0000000000000000000000000000000000000000..4c9c43c4272f3c2b1c824a2d81c5ac7b0650e49a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional_base.h @@ -0,0 +1,475 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at { +namespace detail { +// We prefer to convert through float for reduced-precision floating +// point types if we have a Vectorized specialization for float and we +// don't have one for the actual type in question. 
+template +struct should_prefer_converting_through_float + : std::bool_constant< + is_reduced_floating_point_v && + vec::is_vec_specialized_for_v && + !vec::is_vec_specialized_for_v> {}; + +template +constexpr auto should_prefer_converting_through_float_v = + should_prefer_converting_through_float::value; +} // namespace detail + +namespace vec { +// slow path +template +inline scalar_t vec_reduce_all( + const Op& vec_fun, + vec::Vectorized acc_vec, + int64_t size) { + using Vec = vec::Vectorized; + scalar_t acc_arr[Vec::size()]; + acc_vec.store(acc_arr); + for (const auto i : c10::irange(1, size)) { + std::array acc_arr_next = {0}; + acc_arr_next[0] = acc_arr[i]; + Vec acc_vec_next = Vec::loadu(acc_arr_next.data()); + acc_vec = vec_fun(acc_vec, acc_vec_next); + } + acc_vec.store(acc_arr); + return acc_arr[0]; +} + +template +struct VecReduceAllSIMD { + static inline scalar_t apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + return vec_reduce_all(vec_fun, acc_vec, Vectorized::size()); + } +}; + +#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \ + !defined(C10_MOBILE) +#if defined(CPU_CAPABILITY_AVX2) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + Vec v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = vec_fun(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = vec_fun(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = vec_fun(v, v1); + return _mm256_cvtss_f32(v); + } +}; +#endif // defined(CPU_CAPABILITY_AVX2) +#if defined(CPU_CAPABILITY_AVX512) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 256-bit shuffle + Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E); + v = vec_fun(v, v1); + // 128-bit shuffle + v1 = _mm512_shuffle_f32x4(v, v, 0xB1); + v = vec_fun(v, v1); + // 64-bit shuffle + v1 = _mm512_shuffle_ps(v, v, 0x4E); + v = vec_fun(v, v1); + // 32-bit shuffle + v1 = _mm512_shuffle_ps(v, v, 0xB1); + v = vec_fun(v, v1); + return _mm512_cvtss_f32(v); + } +}; +#endif // defined(CPU_CAPABILITY_AVX512) +#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && + // !defined(C10_MOBILE) + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + + // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, + // a4+a8, a1+a5, a2+a6, -, -, -, -] + float32x4_t v1_1 = vextq_f32(v, v, 2); + Vec v1 = v1_1; + // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] + v = vec_fun(v, v1); + + // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, + // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, + // -] + v1_1 = vrev64q_f32(v); + v1 = v1_1; + // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, + // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] + v = vec_fun(v, v1); + + return v[0]; + } +}; + +template <> +struct VecReduceAllSIMD>> { + static inline float apply( + const std::plus>& vec_fun, + const Vectorized& acc_vec) { + return vaddvq_f32(acc_vec); + } +}; +#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + // && !defined(CPU_CAPABILITY_SVE) + +#if 
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + defined(CPU_CAPABILITY_SVE256) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + svuint32_t ind = svdupq_n_u32(4, 5, 6, 7); + Vec v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 64-bit shuffle + ind = svdupq_n_u32(2, 3, 0, 1); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 32-bit shuffle + ind = svdupq_n_u32(1, 0, 2, 3); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + return svlasta(svpfalse(), v); + } +}; +#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + // && defined(CPU_CAPABILITY_SVE256) + +template +inline scalar_t vec_reduce_all( + const Op& vec_fun, + const Vectorized& acc_vec) { + return VecReduceAllSIMD::apply(vec_fun, acc_vec); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline scalar_t reduce_all( + const Op& vec_fun, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) + return vec_reduce_all(vec_fun, Vec::loadu(data, size), size); + int64_t d = Vec::size(); + Vec acc_vec = Vec::loadu(data); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + acc_vec = vec_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(vec_fun, acc_vec); +} + +// similar to reduce_all, but reduces into two outputs +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + auto loaded_data = Vec::loadu(data, size); + return std::pair( + vec_reduce_all(vec_fun1, loaded_data, size), + vec_reduce_all(vec_fun2, loaded_data, size)); + } + int64_t d = Vec::size(); + Vec acc_vec1 = Vec::loadu(data); + Vec acc_vec2 = Vec::loadu(data); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + acc_vec1 = vec_fun1(acc_vec1, data_vec); + acc_vec2 = vec_fun2(acc_vec2, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d); + acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d); + } + return std::pair( + vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2)); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) + return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size); + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + data_vec = map_fun(data_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + data_vec = map_fun(data_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + 
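+
+// Illustrative sketch (not part of the original header): using the reduction
+// helpers above on a plain float buffer. The lambdas receive whole
+// Vectorized<float> lanes; `maximum` and `abs` are assumed to be the usual
+// at::vec element-wise helpers, and the buffer/size values are made up.
+//
+//   float data[64] = {/* ... */};
+//   float sum = at::vec::reduce_all<float>(
+//       [](auto& x, auto& y) { return x + y; }, data, 64);
+//   float max_abs = at::vec::map_reduce_all<float>(
+//       [](auto x) { return x.abs(); },
+//       [](auto& x, auto& y) { return at::vec::maximum(x, y); },
+//       data, 64);
+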
+template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map2_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + Vec data_vec = Vec::loadu(data, size); + Vec data2_vec = Vec::loadu(data2, size); + data_vec = map_fun(data_vec, data2_vec); + return vec_reduce_all(red_fun, data_vec, size); + } + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + Vec data2_vec = Vec::loadu(data2 + d); + data_vec = map_fun(data_vec, data2_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + Vec data2_vec = Vec::loadu(data2 + d, size - d); + data_vec = map_fun(data_vec, data2_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map3_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + const scalar_t* data3, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + Vec data_vec = Vec::loadu(data, size); + Vec data2_vec = Vec::loadu(data2, size); + Vec data3_vec = Vec::loadu(data3, size); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + return vec_reduce_all(red_fun, data_vec, size); + } + + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + Vec data2_vec = Vec::loadu(data2 + d); + Vec data3_vec = Vec::loadu(data3 + d); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + Vec data2_vec = Vec::loadu(data2 + d, size - d); + Vec data3_vec = Vec::loadu(data3 + d, size - d); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v>, + int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec output_vec = vec_fun(Vec::loadu(input_data + d)); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d)); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map2( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + const scalar_t* input_data2, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = 
Vec::loadu(input_data + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(input_data + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map3( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec1 = Vec::loadu(input_data1 + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec data_vec3 = Vec::loadu(input_data3 + d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec1 = Vec::loadu(input_data1 + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec data_vec3 = Vec::loadu(input_data3 + d, size - d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map4( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + const scalar_t* input_data4, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec1 = Vec::loadu(input_data1 + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec data_vec3 = Vec::loadu(input_data3 + d); + Vec data_vec4 = Vec::loadu(input_data4 + d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec1 = Vec::loadu(input_data1 + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec data_vec3 = Vec::loadu(input_data3 + d, size - d); + Vec data_vec4 = Vec::loadu(input_data4 + d, size - d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4); + output_vec.store(output_data + d, size - d); + } +} + +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..70364cccee54b1445761ad99203f9b5fb9bfb594 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h @@ -0,0 +1,647 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include + +namespace at::vec { +// BFloat16 specification +template +struct VecScalarType { + using type = scalar_t; +}; +template <> +struct VecScalarType { + using type = float; +}; +template <> +struct VecScalarType { + using type = float; +}; + +// This is different from at::acc_type since we only need to specialize BFloat16 +template +using vec_scalar_t = typename VecScalarType::type; + +// Vector conversion between float and bfloat16/half +template <> +inline std::tuple, Vectorized> convert_to_float< + BFloat16>(const Vectorized& a) { + return convert_bfloat16_float(a); +} + +template <> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized& a) { + return convert_half_float(a); +} + +template <> +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { + return convert_float_bfloat16(a, b); +} + +template <> +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { + return convert_float_half(a, b); +} + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float( + const scalar_t* data, + Vectorized& out1, + Vectorized& out2); + +template <> +inline void load_to_float( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_bf16(data, out1, out2); +} + +template <> +inline void load_to_float( + const Half* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_fp16(data, out1, out2); +} + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float(const scalar_t* data, Vectorized& out); + +template <> +inline void load_to_float( + const BFloat16* data, + Vectorized& out) { + load_fp32_from_bf16(data, out); +} + +template <> +inline void load_to_float(const Half* data, Vectorized& out) { + load_fp32_from_fp16(data, out); +} + +// Note that we already have specialized member of Vectorized for +// BFloat16 so the following functions would run smoothly: +// using Vec = Vectorized; +// Vec one = Vec(BFloat16(1)); +// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); +// +// Then why we still need to specialize "functional"? +// If we do specialization at Vectorized<> level, the above example would need +// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and +// "/". If we do specialization at vec::map<>() level, we have only 1 pair of +// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16 +// vector only. +// +// The following BFloat16 functionality will only do data type conversion for +// input and output vector (reduce functionality will only convert the final +// scalar back to bf16). Compared to Vectorized<> specialization, +// 1. better performance since we have less data type conversion; +// 2. less rounding error since immediate results are kept in fp32; +// 3. accumulation done on data type of fp32. 
+// +// If you plan to extend this file, please ensure adding unit tests at +// aten/src/ATen/test/vec_test_all_types.cpp +// +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + data_fvec0 = fVec::set( + data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(vec_fun, data_fvec0, fVec::size()); + } else { + return vec_reduce_all(vec_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); + acc_fvec1 = vec_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + acc_fvec0 = + fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(vec_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + fVec acc1_fvec = fVec::set( + data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc2_fvec = fVec::set( + data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); + return std::pair( + vec_reduce_all(vec_fun1, acc1_fvec, fVec::size()), + vec_reduce_all(vec_fun2, acc2_fvec, fVec::size())); + } else { + return std::pair( + vec_reduce_all(vec_fun1, data_fvec0, size), + vec_reduce_all(vec_fun2, data_fvec0, size)); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc1_fvec0, acc1_fvec1] = convert_to_float(acc_bvec); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc_bvec); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); + acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1); + acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); + acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); + acc1_fvec1 = fVec::set( + acc1_fvec1, + vec_fun1(acc1_fvec1, data_fvec1), + size - d - fVec::size()); + acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); + acc2_fvec1 = fVec::set( + acc2_fvec1, + vec_fun2(acc2_fvec1, 
data_fvec1), + size - d - fVec::size()); + } else { + acc1_fvec0 = + fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); + acc2_fvec0 = + fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); + } + } + acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1); + acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1); + return std::pair( + vec_reduce_all(vec_fun1, acc1_fvec0), + vec_reduce_all(vec_fun2, acc2_fvec0)); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + acc_fvec0 = map_fun(acc_fvec0); + acc_fvec1 = map_fun(acc_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map2_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2, size); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + bVec acc2_bvec = bVec::loadu(data2); + auto [acc2_fvec0, acc2_fvec1] = 
convert_to_float(acc2_bvec); + acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0); + acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map3_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + const scalar_t* data3, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2, size); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3, size); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + bVec acc2_bvec = bVec::loadu(data2); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc2_bvec); + bVec acc3_bvec = bVec::loadu(data3); + auto [acc3_fvec0, acc3_fvec1] = convert_to_float(acc3_bvec); + acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0); + acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, 
data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v>), + int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(input_data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(input_data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const float* input_data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + fVec data_fvec0 = fVec::loadu(input_data + d); + fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size()); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + fVec data_fvec0, data_fvec1; + if (size - d > fVec::size()) { + data_fvec0 = fVec::loadu(input_data + d); + data_fvec1 = + fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); + } else { + // choose to align with behaviour of bVec::loadu(ptr, size), + // which leaves data_fvec1 uninitialized + data_fvec0 = fVec::loadu(input_data + d, size - d); + } + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map2( + const Op& vec_fun, + 
scalar_t* output_data, + const scalar_t* input_data, + const scalar_t* input_data2, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(input_data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(input_data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map3( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data1_bvec = bVec::loadu(input_data1 + d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); + fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); + fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map4( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + const scalar_t* input_data4, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + 
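+  // Each full-width chunk of the reduced-precision inputs expands into two
+  // float vectors; the four inputs are widened, combined lane-wise by
+  // vec_fun, narrowed back to the reduced-precision type, and stored. The
+  // loop below walks whole vector widths; the remainder is handled with a
+  // partial load/store afterwards.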
int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data1_bvec = bVec::loadu(input_data1 + d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + bVec data4_bvec = bVec::loadu(input_data4 + d); + auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + bVec data4_bvec = bVec::loadu(input_data4 + d, size - d); + auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/intrinsics.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..a58d637c94d8dee905764121f82c419c1dd7c2e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/intrinsics.h @@ -0,0 +1,55 @@ +#pragma once +#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC or clang-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) +/* Clang-compatible compiler, targeting arm neon */ +#include +#if defined(__ARM_FEATURE_SVE) +/* CLANG-compatible compiler, targeting ARM with SVE */ +#include +#endif +#elif defined(_MSC_VER) +/* Microsoft C/C++-compatible compiler */ +#include +#if _MSC_VER <= 1900 +#define _mm256_extract_epi64(X, Y) \ + (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) +#define _mm256_extract_epi32(X, Y) \ + (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) +#define _mm256_extract_epi16(X, Y) \ + (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) +#define _mm256_extract_epi8(X, Y) \ + (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) +#endif +#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) +/* GCC-compatible compiler, targeting ARM with NEON */ +#include +#if defined(__ARM_FEATURE_SVE) +/* GCC-compatible compiler, targeting ARM with SVE */ +#include +#endif +#if defined(MISSING_ARM_VLD1) +#include +#elif defined(MISSING_ARM_VST1) +#include +#endif +#elif defined(__GNUC__) && defined(__IWMMXT__) +/* GCC-compatible compiler, targeting ARM with WMMX */ +#include +#elif defined(__s390x__) +// targets Z/architecture +// we will include vecintrin later +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + 
(defined(__VEC__) || defined(__ALTIVEC__)) +/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ +#include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel +#elif defined(__GNUC__) && defined(__SPE__) +/* GCC-compatible compiler, targeting PowerPC with SPE */ +#include +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..37dca8c78a425a9a0dd107a2f9e343fff537c659 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include + +#if defined(CPU_CAPABILITY_SVE) + +// Define the data type of VLS(vector-length specific). +typedef svbool_t vls_pred_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint8_t vls_int8_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint16_t vls_int16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint32_t vls_int32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint64_t vls_int64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint8_t vls_uint8_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint16_t vls_uint16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint32_t vls_uint32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint64_t vls_uint64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat16_t vls_float16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svbfloat16_t vls_bfloat16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat32_t vls_float32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat64_t vls_float64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); + +#define ptrue svptrue_b8() +#define ZERO_S8 svdup_n_s8(0) +#define ZERO_S16 svdup_n_s16(0) +#define ZERO_S32 svdup_n_s32(0) +#define ZERO_S64 svdup_n_s64(0) +#define ZERO_U8 svdup_n_u8(0) +#define ZERO_U16 svdup_n_u16(0) +#define ZERO_U32 svdup_n_u32(0) +#define ZERO_U64 svdup_n_u64(0) +#define ZERO_F16 svdup_n_f16(0.f) +#define ZERO_F32 svdup_n_f32(0.f) +#define ZERO_F64 svdup_n_f64(0.0) +#define ONE_S8 svdup_n_s8(1) +#define ONE_S16 svdup_n_s16(1) +#define ONE_S32 svdup_n_s32(1) +#define ONE_S64 svdup_n_s64(1) +#define ONE_U8 svdup_n_u8(1) +#define ONE_U16 svdup_n_u16(1) +#define ONE_U32 svdup_n_u32(1) +#define ONE_U64 svdup_n_u64(1) +#define ONE_F16 svdup_n_f16(1.f) +#define ONE_BF16 svdup_n_bf16(1.f) +#define ONE_F32 svdup_n_f32(1.f) +#define ONE_F64 svdup_n_f64(1.0) +#define ALL_S8_TRUE_MASK svdup_n_s8(0xff) +#define ALL_S8_FALSE_MASK svdup_n_s8(0x0) +#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff) +#define ALL_S16_FALSE_MASK svdup_n_s16(0x0) +#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff) +#define ALL_S32_FALSE_MASK svdup_n_s32(0x0) +#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff) +#define ALL_S64_FALSE_MASK svdup_n_s64(0x0) +#define ALL_U8_TRUE_MASK svdup_n_u8(0x01) +#define ALL_U8_FALSE_MASK svdup_n_u8(0x00) +#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK) +#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK) +#define ALL_BF16_TRUE_MASK svreinterpret_bf16_s16(ALL_S16_TRUE_MASK) 
+#define ALL_BF16_FALSE_MASK svreinterpret_bf16_s16(ALL_S16_FALSE_MASK) +#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK) +#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK) +#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK) +#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK) + +#endif // defined(CPU_CAPABILITY_SVE) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..8a720e1a367ce1f07de7a01ccad6a12f40de7bf3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h @@ -0,0 +1,580 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_bfloat16_t values; + + public: + using value_type = BFloat16; + using size_type = int; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(BFloat16); + } + + Vectorized() {} + Vectorized(svbfloat16_t v) : values(v) {} + Vectorized(int val); + Vectorized(BFloat16 val); + + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... 
vals) { + __at_align__ BFloat16 buffer[size()] = {vals...}; + values = svld1_bf16(ptrue, reinterpret_cast(buffer)); + } + + operator svbfloat16_t() const { + return values; + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_), ALL_S16_TRUE_MASK); + return svsel_bf16(mask, b, a); + } + template + static Vectorized arange( + BFloat16 base = 0.f, + step_t step = static_cast(1)) { + __at_align__ BFloat16 buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_bf16(ptrue, reinterpret_cast(buffer)); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_bf16(svwhilelt_b16(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_bf16(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b16(0ull, count); + return svld1_bf16(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + __at_align__ bfloat16_t tmp[size()]; + std::memset(tmp, 0, sizeof(tmp)); + if (count == size()) { + svst1_bf16(ptrue, reinterpret_cast(tmp), values); + } else { + svbool_t pg = svwhilelt_b16(0ull, count); + svst1_bf16(pg, reinterpret_cast(tmp), values); + } + std::memcpy( + reinterpret_cast(ptr), + reinterpret_cast(tmp), + count * sizeof(bfloat16_t)); + } + const BFloat16& operator[](int idx) const = delete; + BFloat16& operator[](int idx) = delete; + int64_t zero_mask() const { + int64_t mask = 0; + // returns an integer mask where all zero elements are translated to + // 1-bit and others are translated to 0-bit int64_t mask = 0; + __at_align__ int16_t mask_array[size()]; + + svbool_t svbool_mask = + svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16); + svst1_s16( + ptrue, + mask_array, + svsel_s16(svbool_mask, ALL_S16_TRUE_MASK, ALL_S16_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const; + bool has_inf_nan() const; + Vectorized map(BFloat16 (*f)(BFloat16)) const { + __at_align__ BFloat16 tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = svdup_n_u16(0x7FFF); + auto vals = svreinterpret_u16_bf16(values); + vals = svand_u16_x(ptrue, vals, mask); + return svreinterpret_bf16_u16(vals); + } + Vectorized angle() const; + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const; + Vectorized acosh() const; + Vectorized asin() const; + Vectorized atan() const; + Vectorized atanh() const; + Vectorized atan2(const Vectorized& b) const; + Vectorized copysign(const Vectorized& sign) const; + Vectorized erf() const; + Vectorized erfc() const; + Vectorized erfinv() const; + Vectorized exp() const; + Vectorized exp2() const; + Vectorized expm1() const; + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const; + Vectorized hypot(const Vectorized& b) const; + Vectorized i0() const; + Vectorized i0e() const; + Vectorized digamma() const; + Vectorized igamma(const Vectorized& x) const; + Vectorized igammac(const Vectorized& x) const; + Vectorized 
nextafter(const Vectorized& b) const; + Vectorized log() const; + Vectorized log2() const; + Vectorized log10() const; + Vectorized log1p() const; + Vectorized frac() const; + Vectorized sin() const; + Vectorized sinh() const; + Vectorized cos() const; + Vectorized cosh() const; + Vectorized ceil() const; + Vectorized floor() const; + Vectorized neg() const { + auto mask = svdup_n_u16(0x8000); + auto vals = svreinterpret_u16_bf16(values); + vals = sveor_u16_x(ptrue, vals, mask); + return svreinterpret_bf16_u16(vals); + }; + Vectorized round() const; + Vectorized tan() const; + Vectorized tanh() const; + Vectorized trunc() const; + Vectorized lgamma() const; + Vectorized sqrt() const; + Vectorized reciprocal() const; + Vectorized rsqrt() const; + Vectorized pow(const Vectorized& b) const; + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const; + + Vectorized operator!=(const Vectorized& other) const; + + Vectorized operator<(const Vectorized& other) const; + + Vectorized operator<=(const Vectorized& other) const; + + Vectorized operator>(const Vectorized& other) const; + + Vectorized operator>=(const Vectorized& other) const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f)); + auto bf16_vec1 = svzip1_bf16(zero, a); + auto bf16_vec2 = svzip2_bf16(zero, a); + auto x1 = svreinterpret_f32_bf16(bf16_vec1); + auto x2 = svreinterpret_f32_bf16(bf16_vec2); + return {Vectorized(x1), Vectorized(x2)}; +} + +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a); + svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b); + return Vectorized(svuzp1_bf16(x1, x2)); +} + +inline void load_fp32_from_bf16(const BFloat16* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_bf16( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + Vectorized bf16_vec = Vectorized::loadu(data); + auto floats = convert_bfloat16_float(bf16_vec); + out1 = std::get<0>(floats); + out2 = std::get<1>(floats); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return convert_float_bfloat16( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::plus>(), a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::minus>(), a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) 
{ + return binary_operator_via_float(std::multiplies>(), a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::divides>(), a, b); +} + +inline Vectorized::Vectorized(int val) { + auto vals_f = svdup_n_f32(val); + values = convert_float_bfloat16(vals_f, vals_f); +} + +inline Vectorized::Vectorized(BFloat16 val) { + auto vals_f = svdup_n_f32((float)val); + values = convert_float_bfloat16(vals_f, vals_f); +} + +bool inline Vectorized::has_inf_nan() const { + auto [v1, v2] = convert_bfloat16_float(values); + return v1.has_inf_nan() || v2.has_inf_nan(); +} +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +#define DEFINE_BF16_FUNC_VIA_FLOAT(func_name) \ + Vectorized inline Vectorized::func_name() const { \ + auto [v1, v2] = convert_bfloat16_float(*this); \ + v1 = v1.func_name(); \ + v2 = v2.func_name(); \ + return convert_float_bfloat16(v1, v2); \ + } + +#define DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(func_name) \ + Vectorized inline Vectorized::func_name( \ + const Vectorized& a) const { \ + auto [v1, v2] = convert_bfloat16_float(*this); \ + auto [v3, v4] = convert_bfloat16_float(a); \ + v1 = v1.func_name(v3); \ + v2 = v2.func_name(v4); \ + return convert_float_bfloat16(v1, v2); \ + } + +DEFINE_BF16_FUNC_VIA_FLOAT(isnan); +DEFINE_BF16_FUNC_VIA_FLOAT(angle); +DEFINE_BF16_FUNC_VIA_FLOAT(acos); +DEFINE_BF16_FUNC_VIA_FLOAT(acosh); +DEFINE_BF16_FUNC_VIA_FLOAT(asin); +DEFINE_BF16_FUNC_VIA_FLOAT(atan); +DEFINE_BF16_FUNC_VIA_FLOAT(atanh); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign); +DEFINE_BF16_FUNC_VIA_FLOAT(erf); +DEFINE_BF16_FUNC_VIA_FLOAT(erfc); +DEFINE_BF16_FUNC_VIA_FLOAT(exp); +DEFINE_BF16_FUNC_VIA_FLOAT(exp2); +DEFINE_BF16_FUNC_VIA_FLOAT(expm1); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot); +DEFINE_BF16_FUNC_VIA_FLOAT(i0); +DEFINE_BF16_FUNC_VIA_FLOAT(i0e); +DEFINE_BF16_FUNC_VIA_FLOAT(digamma); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter); +DEFINE_BF16_FUNC_VIA_FLOAT(log); +DEFINE_BF16_FUNC_VIA_FLOAT(log2); +DEFINE_BF16_FUNC_VIA_FLOAT(log10); +DEFINE_BF16_FUNC_VIA_FLOAT(log1p); +DEFINE_BF16_FUNC_VIA_FLOAT(sin); +DEFINE_BF16_FUNC_VIA_FLOAT(sinh); +DEFINE_BF16_FUNC_VIA_FLOAT(cos); +DEFINE_BF16_FUNC_VIA_FLOAT(cosh); +DEFINE_BF16_FUNC_VIA_FLOAT(ceil); +DEFINE_BF16_FUNC_VIA_FLOAT(floor); +DEFINE_BF16_FUNC_VIA_FLOAT(round); +DEFINE_BF16_FUNC_VIA_FLOAT(tan); +DEFINE_BF16_FUNC_VIA_FLOAT(tanh); +DEFINE_BF16_FUNC_VIA_FLOAT(trunc); +DEFINE_BF16_FUNC_VIA_FLOAT(lgamma); +DEFINE_BF16_FUNC_VIA_FLOAT(sqrt); +DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal); +DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow); + +Vectorized inline Vectorized::operator==( + const Vectorized& other) const { + auto [f1, f2] = convert_bfloat16_float(values); + auto [f3, f4] = convert_bfloat16_float(other); + svbool_t mask1 = svcmpeq_f32(ptrue, f1, f3); + svbool_t mask2 = svcmpeq_f32(ptrue, f2, f4); + auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + + auto bf16_1 = svreinterpret_bf16_f32(res1); + auto bf16_2 = svreinterpret_bf16_f32(res2); + return svuzp1_bf16(bf16_1, bf16_2); +} +Vectorized inline Vectorized::operator!=( + const Vectorized& other) const { + auto [f1, f2] = 
convert_bfloat16_float(values); + auto [f3, f4] = convert_bfloat16_float(other); + svbool_t mask1 = svcmpne_f32(ptrue, f1, f3); + svbool_t mask2 = svcmpne_f32(ptrue, f2, f4); + auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + + auto bf16_1 = svreinterpret_bf16_f32(res1); + auto bf16_2 = svreinterpret_bf16_f32(res2); + return svuzp1_bf16(bf16_1, bf16_2); +} +Vectorized inline Vectorized::operator>( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 > v3, v2 > v4); +} +Vectorized inline Vectorized::operator>=( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 >= v3, v2 >= v4); +} +Vectorized inline Vectorized::operator<( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 < v3, v2 < v4); +} +Vectorized inline Vectorized::operator<=( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 <= v3, v2 <= v4); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
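+// Like `maximum` above, this is evaluated by widening both operands to
+// float, taking the float minimum, and narrowing the result back to
+// bfloat16.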
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&clamp_max), + a, + max); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&clamp_min), + a, + min); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return clamp_min(clamp_max(a, max), min); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + svand_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + svorr_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + sveor_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_bf16( + ptrue, + const_cast(reinterpret_cast(dst)) + i, + svldnt1_bf16( + ptrue, + const_cast(reinterpret_cast(src)) + + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b16(i, n); + svst1_bf16( + pg, + const_cast(reinterpret_cast(dst)) + i, + svldnt1_bf16( + pg, + const_cast(reinterpret_cast(src)) + + i)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b + c; +} + +#endif // defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h new file mode 100644 index 0000000000000000000000000000000000000000..f97dab3a33ebe9b4cddf118b1b22573a6f8a434e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h @@ -0,0 +1,236 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with SVE] + +#include + +#include +#include + +#if defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#include +#include +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \ + template <> \ + inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_##t1_prefix##_##t2_prefix(src); \ + } \ + template <> \ + inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_##t2_prefix##_##t1_prefix(src); \ + } + +DEFINE_SVE_CAST(int64_t, s64, double, f64) +DEFINE_SVE_CAST(int32_t, s32, double, f64) +DEFINE_SVE_CAST(int16_t, s16, double, f64) +DEFINE_SVE_CAST(int64_t, s64, float, f32) +DEFINE_SVE_CAST(int32_t, s32, float, f32) +DEFINE_SVE_CAST(int16_t, s16, float, f32) +DEFINE_SVE_CAST(float, f32, double, f64) + +#ifdef __ARM_FEATURE_BF16 +DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16) +DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16) +DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16) +#endif // __ARM_FEATURE_BF16 + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex_) { + svint64_t vindex = + svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svld1_gather_s64index_f64(ptrue, base_addr, vindex); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex_) { + svint32_t vindex = + svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svld1_gather_s32index_f32(ptrue, base_addr, vindex); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex_, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); + svint64_t vindex = + svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svsel_f64( + mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex_, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); + svint32_t vindex = + svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svsel_f32( + mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: 
https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); + return svsub_s64_x( + ptrue, + svreinterpret_s64_f64(x), + svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return svcvt_s32_f32_x(ptrue, src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3} + // b = {b0, b1, b2, b3} + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized(svzip1_f64(a, b)), + Vectorized(svzip2_f64(a, b))); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair( + Vectorized(svzip1_f32(a, b)), Vectorized(svzip2_f32(a, b))); +} + +#ifdef __ARM_FEATURE_BF16 +template <> +std::pair< + Vectorized, + Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair( + Vectorized(svzip1_bf16(a, b)), + Vectorized(svzip2_bf16(a, b))); +} +#endif // __ARM_FEATURE_BF16 + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vectorized(svuzp1_f64(a, b)), + Vectorized(svuzp2_f64(a, b))); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + Vectorized(svuzp1_f32(a, b)), Vectorized(svuzp2_f32(a, b))); +} + +#ifdef __ARM_FEATURE_BF16 +template <> +std::pair< + Vectorized, + Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + Vectorized(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)), + Vectorized(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b))); +} +#endif // __ARM_FEATURE_BF16 + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h new file mode 100644 index 0000000000000000000000000000000000000000..1c385a9d7fb1c1773318d5c65fb5e81a9d10fa2b --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h @@ -0,0 +1,588 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_float64_t values; + + public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(double); + } + Vectorized() {} + Vectorized(svfloat64_t v) : values(v) {} + Vectorized(double val) { + values = svdup_n_f64(val); + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ double buffer[size()] = {vals...}; + values = svld1_f64(ptrue, buffer); + } + operator svfloat64_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each element is 1 if the corresponding bit in + // 'mask' is set, 0 otherwise. + __at_align__ int64_t flag_arr[size()]; + for (int i = 0; i < size(); i++) { + flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0; + } + // Load the flag array into an SVE int64 vector. + svint64_t int_mask = svld1_s64(svptrue_b64(), flag_arr); + // Compare each lane of int_mask to 0; returns an svbool_t predicate where + // true indicates a nonzero flag. + svbool_t blend_mask = svcmpne_n_s64(svptrue_b64(), int_mask, 0); + + // Use svsel to select elements from b where the predicate is true, else + // from a. 
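+    // For example, with four double lanes, blend<0b0101>(a, b) keeps lanes 0
+    // and 2 from b and lanes 1 and 3 from a.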
+ svfloat64_t result = svsel(blend_mask, b.values, a.values); + return Vectorized(result); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); + return svsel_f64(mask, b, a); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + __at_align__ double buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f64(ptrue, buffer); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f64(svwhilelt_b64(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f64(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b64(0ull, count); + return svld1_f64(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f64(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b64(0ull, count); + svst1_f64(pg, reinterpret_cast(ptr), values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int64_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f64(ptrue, values, ZERO_F64); + svst1_s64( + ptrue, + mask_array, + svsel_s64(svbool_mask, ALL_S64_TRUE_MASK, ALL_S64_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f64(ptrue, values, ZERO_F64); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any( + ptrue, + svcmpuo_f64(ptrue, svsub_f64_x(ptrue, values, values), ZERO_F64)); + } + Vectorized map(double (*f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f64_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f64(NAN); + const auto nan_mask = svcmpuo_f64(ptrue, values, ZERO_F64); + const auto pi = svdup_n_f64(c10::pi); + + const auto neg_mask = svcmplt_f64(ptrue, values, ZERO_F64); + auto angle = svsel_f64(neg_mask, pi, ZERO_F64); + angle = svsel_f64(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosdx_u10sve(values)), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshdx_u10sve(values)), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asindx_u10sve(values)), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhdx_u10sve(values)), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atandx_u10sve(values)), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + 
Vectorized(Sleef_atanhdx_u10sve(values)), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2dx_u10sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized copysign(const Vectorized& sign) const { + USE_SLEEF( + { return Vectorized(Sleef_copysigndx_sve(values, sign)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erfdx_u10sve(values)), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcdx_u15sve(values)), map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF( + Vectorized(Sleef_expdx_u10sve(values)), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2dx_u10sve(values)), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1dx_u10sve(values)), map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmoddx_sve(values, q)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotdx_u05sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterdx_sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logdx_u10sve(values)), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2dx_u10sve(values)), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10dx_u10sve(values)), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + 
Vectorized(Sleef_log1pdx_u10sve(values)), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sindx_u10sve(values)), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhdx_u10sve(values)), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosdx_u10sve(values)), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshdx_u10sve(values)), map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f64_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f64_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f64_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f64_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tandx_u10sve(values)), map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( + Vectorized(Sleef_tanhdx_u10sve(values)), map(std::tanh)); + } + Vectorized trunc() const { + return svrintz_f64_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammadx_u10sve(values)), map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f64_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f64_x(ptrue, values, ONE_F64); + } + Vectorized rsqrt() const { + return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powdx_u10sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return svadd_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return svsub_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return svmul_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_f64_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return svmax_f64_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
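+// For double this lowers directly to the SVE minimum intrinsic rather than a
+// widen/narrow round trip.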
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return svmin_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return svmin_f64_x(ptrue, max, svmax_f64_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return svmin_f64_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return svmax_f64_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + svand_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + svorr_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + sveor_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f64(ptrue, dst + i, svldnt1_f64(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b64(i, n); + svst1_f64(pg, dst + i, svldnt1_f64(pg, src + i)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmad_f64_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h new file mode 100644 index 0000000000000000000000000000000000000000..546b6649c0ef03676f3ab01301054af2ca296089 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h @@ -0,0 +1,759 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. 
So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_float32_t values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(float); + } + Vectorized() {} + Vectorized(svfloat32_t v) : values(v) {} + Vectorized(float val) { + values = svdup_n_f32(val); + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ float buffer[size()] = {vals...}; + values = svld1_f32(ptrue, buffer); + } + operator svfloat32_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each element is 1 if the corresponding bit in + // 'mask' is set, 0 otherwise. + __at_align__ int32_t flag_arr[size()]; + for (int i = 0; i < size(); i++) { + flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0; + } + // Load the flag array into an SVE int32 vector. + svint32_t int_mask = svld1_s32(svptrue_b32(), flag_arr); + // Compare each lane of int_mask to 0; returns an svbool_t predicate where + // true indicates a nonzero flag. + svbool_t blend_mask = svcmpne_n_s32(svptrue_b32(), int_mask, 0); + // Use svsel to select elements from b where the predicate is true, else + // from a. + svfloat32_t result = svsel_f32(blend_mask, b.values, a.values); + return Vectorized(result); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); + return svsel_f32(mask, b, a); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + __at_align__ float buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f32(ptrue, buffer); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f32(svwhilelt_b32(0ull, count), b, a); + } + return b; + } + // Implementation is picked from + // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105 + inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const { + const auto c1 = + svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f + const auto c2 = + svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f + const auto c3 = + svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f + const auto c4 = + svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f + const auto c5 = + svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f + const auto shift = svreinterpret_f32_u32( + svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = svreinterpret_f32_u32( + svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32( + 0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32( + 0xb5bfbe8e)); // -ln(2) from bits -20 to -42: 
-0x1.7f7d1cp-20f + const auto inf = svdup_n_f32(std::numeric_limits::infinity()); + const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = svdup_n_f32(0.f); + const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125) + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 + // forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. + // n) + 127 will occupy the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part + // of x / ln(2) (i.e. n) because the decimal part has been pushed out + // and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be + // used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = svmla_f32_z(pg, shift, x, inv_ln2); + const auto n = svsub_f32_z(pg, z, shift); + const auto scale = svreinterpret_f32_u32( + svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy + // beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in + // term of accuracy and performance. + const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi); + const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo); + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = svmul_f32_z(pg, r, r); + const auto p1 = svmul_f32_z(pg, c1, r); + const auto p23 = svmla_f32_z(pg, c2, c3, r); + const auto p45 = svmla_f32_z(pg, c4, c5, r); + const auto p2345 = svmla_f32_z(pg, p23, p45, r2); + const auto p12345 = svmla_f32_z(pg, p1, p2345, r2); + auto poly = svmla_f32_z(pg, scale, p12345, scale); + // Handle underflow and overflow. 
+ poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly); + poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly); + return poly; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_f32(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f32(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b32(0ull, count); + svst1_f32(pg, reinterpret_cast(ptr), values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int32_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32); + svst1_s32( + ptrue, + mask_array, + svsel_s32(svbool_mask, ALL_S32_TRUE_MASK, ALL_S32_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any( + ptrue, + svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32)); + } + Vectorized map(float (*f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f32_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f32(NAN); + const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32); + const auto pi = svdup_n_f32(c10::pi); + + const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32); + auto angle = svsel_f32(neg_mask, pi, ZERO_F32); + angle = svsel_f32(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosfx_u10sve(values)), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshfx_u10sve(values)), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asinfx_u10sve(values)), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhfx_u10sve(values)), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atanfx_u10sve(values)), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + Vectorized(Sleef_atanhfx_u10sve(values)), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2fx_u10sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized copysign(const Vectorized& sign) const { + + USE_SLEEF( + { return Vectorized(Sleef_copysignfx_sve(values, sign)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < 
size(); ++i) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erffx_u10sve(values)), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcfx_u15sve(values)), map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF( + Vectorized(Sleef_expfx_u10sve(values)), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2fx_u10sve(values)), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1fx_u10sve(values)), map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmodfx_sve(values, q)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotfx_u05sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterfx_sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logfx_u10sve(values)), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2fx_u10sve(values)), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10fx_u10sve(values)), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + Vectorized(Sleef_log1pfx_u10sve(values)), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sinfx_u10sve(values)), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhfx_u10sve(values)), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosfx_u10sve(values)), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshfx_u10sve(values)), map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f32_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f32_x(ptrue, values); + } + Vectorized neg() const 
{ + return svneg_f32_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f32_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tanfx_u10sve(values)), map(std::tan)); + } + // Implementation is picked from + // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179 + Vectorized tanh() const { + // Constants used for the tanh calculation. + const svfloat32_t CONST_1 = + svdup_n_f32(1.f); // Constant 1.0f for the tanh formula. + const svfloat32_t CONST_2 = svdup_n_f32( + 2.f); // Constant 2.0f for the tanh formula (used in exp(2x)). + const svfloat32_t CONST_MIN_TANH = svdup_n_f32( + -10.f); // Minimum threshold for input values to prevent overflow. + const svfloat32_t CONST_MAX_TANH = svdup_n_f32( + 10.f); // Maximum threshold for input values to prevent overflow. + + // Step 1: Clamp the values within the range [-10, 10] to prevent overflow + // during exponentiation. The tanh function approaches ±1 rapidly as the + // input grows large, so we limit the input range to avoid numerical + // instability. svmax_f32_z ensures values are greater than -10, and + // svmin_f32_z ensures they are less than 10. + svfloat32_t x = svmin_f32_z( + ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH); + + // Step 2: Calculate exp(2 * x), where x is the clamped value. + // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of + // the result. + svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x)); + + // Step 3: Calculate the numerator of the tanh function, which is exp(2x) + // - 1. + svfloat32_t num = svsub_f32_z(ptrue, exp2x, CONST_1); + + // Step 4: Calculate the denominator of the tanh function, which is exp(2x) + // + 1. + svfloat32_t den = svadd_f32_z(ptrue, exp2x, CONST_1); + + // Step 5: Calculate the tanh function as the ratio of the numerator and + // denominator: num / den. + svfloat32_t tanh = svdiv_f32_z(ptrue, num, den); + + // Return the calculated tanh values. + return tanh; + } + Vectorized trunc() const { + return svrintz_f32_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammafx_u10sve(values)), map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f32_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f32_x(ptrue, values, ONE_F32); + } + Vectorized rsqrt() const { + return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powfx_u10sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return svadd_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return svsub_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return svmul_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_f32_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return svmax_f32_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return svmin_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return svmin_f32_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return svmax_f32_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b32(i, n); + svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i)); + } +} + +template <> +inline void convert(const float* src, at::Half* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svuzp1_f16( + svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svuzp1_f16( + svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +} + +template <> +inline void convert(const at::Half* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svzip1_f16( + 
svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svzip1_f16( + svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +} + +template <> +inline void convert(const bool* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmad_f32_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h new file mode 100644 index 0000000000000000000000000000000000000000..dfea7d2ca4f597ef4661e25eda57a52ff7ac9e22 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h @@ -0,0 +1,497 @@ +#pragma once + +#include +#include +#include + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +#define VEC_INT_SVE_TEMPLATE(vl, bit) \ + template <> \ + struct is_vec_specialized_for : std::bool_constant {}; \ + \ + template <> \ + class Vectorized { \ + private: \ + vls_int##bit##_t values; \ + \ + public: \ + using value_type = int##bit##_t; \ + using size_type = int; \ + static constexpr size_type size() { \ + return vl; \ + } \ + Vectorized() {} \ + Vectorized(svint##bit##_t v) : values(v) {} \ + Vectorized(int##bit##_t val) { \ + values = svdup_n_s##bit(val); \ + } \ + template < \ + typename... Args, \ + typename = std::enable_if_t<(sizeof...(Args) == size())>> \ + Vectorized(Args... 
vals) { \ + __at_align__ int##bit##_t buffer[size()] = {vals...}; \ + values = svld1_s##bit(ptrue, buffer); \ + } \ + operator svint##bit##_t() const { \ + return values; \ + } \ + template \ + static Vectorized blend( \ + const Vectorized& a, \ + const Vectorized& b) { \ + __at_align__ int##bit##_t flag_arr[size()]; \ + for (int i = 0; i < size(); ++i) { \ + flag_arr[i] = (i < 64 && (mask & (1ULL << i))) ? 1 : 0; \ + } \ + svbool_t blend_mask = svcmpne_n_s##bit( \ + svptrue_b##bit(), svld1_s##bit(svptrue_b##bit(), flag_arr), 0); \ + return Vectorized( \ + svsel_s##bit(blend_mask, b.values, a.values)); \ + } \ + static Vectorized blendv( \ + const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& mask_) { \ + svbool_t mask = svcmpeq_s##bit(ptrue, mask_, ALL_S##bit##_TRUE_MASK); \ + return svsel_s##bit(mask, b, a); \ + } \ + /* step sometimes requires a higher precision type (e.g., T=int, \ + * step_t=double) */ \ + template \ + static Vectorized arange( \ + int##bit##_t base = 0, \ + step_t step = static_cast(1)) { \ + __at_align__ int##bit##_t buffer[size()]; \ + for (int64_t i = 0; i < size(); i++) { \ + buffer[i] = base + i * step; \ + } \ + return svld1_s##bit(ptrue, buffer); \ + } \ + static Vectorized set( \ + const Vectorized& a, \ + const Vectorized& b, \ + int##bit##_t count = size()) { \ + if (count == 0) { \ + return a; \ + } else if (count < size()) { \ + return svsel_s##bit(svwhilelt_b##bit(0ull, count), b, a); \ + } \ + return b; \ + } \ + static Vectorized loadu( \ + const void* ptr, \ + int64_t count = size()) { \ + if (count == size()) \ + return svld1_s##bit( \ + ptrue, reinterpret_cast(ptr)); \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + return svld1_s##bit(pg, reinterpret_cast(ptr)); \ + } \ + void store(void* ptr, int64_t count = size()) const { \ + if (count == size()) { \ + svst1_s##bit(ptrue, reinterpret_cast(ptr), values); \ + } else { \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + svst1_s##bit(pg, reinterpret_cast(ptr), values); \ + } \ + } \ + const int##bit##_t& operator[](int idx) const = delete; \ + int##bit##_t& operator[](int idx) = delete; \ + Vectorized abs() const { \ + return svabs_s##bit##_x(ptrue, values); \ + } \ + Vectorized real() const { \ + return values; \ + } \ + Vectorized imag() const { \ + return svdup_n_s##bit(0); \ + } \ + Vectorized conj() const { \ + return values; \ + } \ + Vectorized frac() const; \ + Vectorized neg() const { \ + return svneg_s##bit##_x(ptrue, values); \ + } \ + Vectorized operator==( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpeq_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator!=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpne_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<( \ + const Vectorized& other) const { \ + svbool_t mask = svcmplt_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmple_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpgt_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); 
\ + } \ + Vectorized operator>=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpge_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized eq(const Vectorized& other) const; \ + Vectorized ne(const Vectorized& other) const; \ + Vectorized gt(const Vectorized& other) const; \ + Vectorized ge(const Vectorized& other) const; \ + Vectorized lt(const Vectorized& other) const; \ + Vectorized le(const Vectorized& other) const; \ + }; \ + template <> \ + Vectorized inline operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return svadd_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return svsub_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return svmul_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline maximum( \ + const Vectorized& a, const Vectorized& b) { \ + return svmax_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline minimum( \ + const Vectorized& a, const Vectorized& b) { \ + return svmin_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, svmax_s##bit##_x(ptrue, min, a)); \ + } \ + template <> \ + Vectorized inline clamp_max( \ + const Vectorized& a, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, a); \ + } \ + template <> \ + Vectorized inline clamp_min( \ + const Vectorized& a, \ + const Vectorized& min) { \ + return svmax_s##bit##_x(ptrue, min, a); \ + } \ + template <> \ + Vectorized inline operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return svand_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return svorr_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return sveor_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + inline Vectorized operator~( \ + const Vectorized& a) { \ + return sveor_s##bit##_x(ptrue, a, svdup_n_s##bit(-1)); \ + } \ + Vectorized inline Vectorized::eq( \ + const Vectorized& other) const { \ + return (*this == other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ne( \ + const Vectorized& other) const { \ + return (*this != other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::gt( \ + const Vectorized& other) const { \ + return (*this > other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ge( \ + const Vectorized& other) const { \ + return (*this >= other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::lt( \ + const Vectorized& other) const { \ + return (*this < other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::le( \ + const Vectorized& other) const { \ + return (*this <= other) & Vectorized(1); \ + } + +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int64_t), 64) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int32_t), 32) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int16_t), 16) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int8_t), 8) + +template +Vectorized inline intdiv_nosve( + const Vectorized& a, + const Vectorized& b) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != 
Vectorized::size(); i++) { + values_a[i] /= values_b[i]; + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_s64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_s32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +inline void convert(const int32_t* src, int64_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); + } +} + +template <> +inline void convert(const int64_t* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = + svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = + svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b32(i, n); + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +} + +template <> +inline void convert(const bool* src, int64_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = + svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_64 = svwhilelt_b64(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t 
src_vec_u64 = + svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +} + +template <> +inline void convert(const bool* src, int32_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +} + +template <> +inline void convert(const uint8_t* src, bool* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b8(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8( + pg, + reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b8(i, n); + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8( + pg, + reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..92fa9ba89c0644f114e5d84e518d01c3894ac44f --- 
/dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h @@ -0,0 +1,606 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include +#include +#include +#include +#include +#include + +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with SVE. This may not be an issue, because +// currently for quantization we assume the user has at least SVE +// installed, so these can simply act as a reference implementation. 
+// +// If in the future we relax this requirement (SVE+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + using size_type = int; + static constexpr size_type size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size() / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (size_t i = 0; i < size(); ++i) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = at::native::dequantize_val( + tmp_scale[j], + tmp_zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = at::native::dequantize_val( + tmp_scale[j], + tmp_zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return svld1_s32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_s32(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (size_t i = 0; i < size(); ++i) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = + nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does 
not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ 
value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return svld1_u8(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b8(0ull, count); + return svld1_u8(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec.h new file mode 100644 index 0000000000000000000000000000000000000000..dff77adfdaed051652d535f31c49a335b4340e82 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec.h @@ -0,0 +1,57 @@ +#pragma once 
+ +#if defined(CPU_CAPABILITY_AVX512) +#include +#else +#include +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline Vectorized convert_to_bool(Vectorized x) { + __at_align__ bool buffer[x.size()]; + x.ne(Vectorized(0)).store(buffer); + + Vectorized ret; + static_assert(x.size() == ret.size()); + std::memcpy(ret, buffer, ret.size() * sizeof(bool)); + return ret; +} + +template <> +inline Vectorized Vectorized::loadu(const void* ptr) { + // See NOTE [Loading boolean values] + return convert_to_bool(Vectorized::loadu(ptr)); +} + +template <> +inline Vectorized Vectorized::loadu( + const void* ptr, + int64_t count) { + // See NOTE [Loading boolean values] + return convert_to_bool(Vectorized::loadu(ptr, count)); +} + +template +struct VecHoldType { + using hold_type = typename VT::value_type; +}; + +template <> +struct VecHoldType> { + using hold_type = BFloat16; +}; + +template <> +struct VecHoldType> { + using hold_type = Half; +}; + +template +using vechold_type = typename VecHoldType::hold_type; + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h new file mode 100644 index 0000000000000000000000000000000000000000..af11498d706c7546c2b82ba6b17c4de18109ba0e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h @@ -0,0 +1,14 @@ +#pragma once +// ARM NEON uses 128-bit vector registers. + +#include + +#ifdef __aarch64__ +#if !defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#endif + +#include +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..9b80f7ad9ba3ed6be3d91b42af4c9ccc3bc57616 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h @@ -0,0 +1,567 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] +#include +#include +#include +#include +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Following vec128_half_neon.h, we only support aarch64. +#if !defined(C10_MOBILE) && defined(__aarch64__) +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." +#endif + +// Unlike the float16_t family of types, bfloat16_t is not available +// when we're not targeting bfloat16 hardware support on some +// platforms (but not Mac, so we have to be careful not to shadow the +// definitions in case they are actually there!). (See +// https://godbolt.org/z/orv6e94n4 ) So, we need to handle it as +// uint16_t in that case. 
+#define IMPLEMENT_AT_BF16_SHIM(vec_suffix) \ + inline at_bfloat16x4_t at_vget_low_bf16(at_bfloat16x8_t a) { \ + return vget_low_##vec_suffix(a); \ + } \ + \ + inline at_bfloat16x4_t at_vget_high_bf16(at_bfloat16x8_t a) { \ + return vget_high_##vec_suffix(a); \ + } \ + \ + inline at_bfloat16x8_t at_vcombine_bf16( \ + at_bfloat16x4_t low, at_bfloat16x4_t high) { \ + return vcombine_##vec_suffix(low, high); \ + } \ + \ + inline at_bfloat16x8_t at_vdupq_n_bf16(at_bfloat16_t value) { \ + return vdupq_n_##vec_suffix(value); \ + } \ + \ + inline at_bfloat16x8_t at_vld1q_bf16(const at_bfloat16_t* ptr) { \ + return vld1q_##vec_suffix(ptr); \ + } \ + \ + inline void at_vst1q_bf16(at_bfloat16_t* ptr, at_bfloat16x8_t value) { \ + vst1q_##vec_suffix(ptr, value); \ + } \ + \ + template \ + inline at_bfloat16x8_t at_vreinterpretq_bf16_u16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpretq_bf16_u16(val); \ + } \ + } \ + template \ + inline at_bfloat16x4_t at_vreinterpret_bf16_u16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpret_bf16_u16(val); \ + } \ + } \ + template \ + inline uint16x8_t at_vreinterpretq_u16_bf16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpretq_u16_bf16(val); \ + } \ + } \ + template \ + inline uint16x4_t at_vreinterpret_u16_bf16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpret_u16_bf16(val); \ + } \ + } + +#ifdef __ARM_FEATURE_BF16 +using at_bfloat16x8_t = bfloat16x8_t; +using at_bfloat16x4_t = bfloat16x4_t; +using at_bfloat16_t = bfloat16_t; +IMPLEMENT_AT_BF16_SHIM(bf16) +#define at_vsetq_lane_bf16 vsetq_lane_bf16 +#define at_vgetq_lane_bf16 vgetq_lane_bf16 +#else +using at_bfloat16x8_t = uint16x8_t; +using at_bfloat16x4_t = uint16x4_t; +using at_bfloat16_t = uint16_t; +IMPLEMENT_AT_BF16_SHIM(u16) +#define at_vsetq_lane_bf16 vsetq_lane_u16 +#define at_vgetq_lane_bf16 vgetq_lane_u16 +#endif // __ARM_FEATURE_BF16 + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res); +}; + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res) { + return at_vsetq_lane_bf16(at_vgetq_lane_bf16(b, index), res, index); + } +}; + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res) { + return at_vsetq_lane_bf16(at_vgetq_lane_bf16(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16< + at_bfloat16x8_t, + c10::BFloat16, + BlendBFloat16Regs, + Vectorized> { + using Base = Vectorized16< + at_bfloat16x8_t, + c10::BFloat16, + BlendBFloat16Regs, + Vectorized>; + friend Base; + friend std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a); + friend Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b); + + private: + Vectorized map2( + const Vectorized& second, + c10::BFloat16 (*const f)(c10::BFloat16, c10::BFloat16)) const { + __at_align__ c10::BFloat16 tmp_first[size()]; + __at_align__ c10::BFloat16 tmp_second[size()]; + store(tmp_first); // store this to tmp_first + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp_first[i] = f(tmp_first[i], tmp_second[i]); + } + 
return loadu(tmp_first); + } + + static float32x4_t convert_f32_bf16(at_bfloat16x4_t bf16) { +#ifdef __ARM_FEATURE_BF16 + return vcvt_f32_bf16(bf16); +#else + int32x4_t shift = vdupq_n_s32(16); + return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(bf16), shift)); +#endif // __ARM_FEATURE_BF16 + } + + static at_bfloat16x4_t convert_bf16_f32(const Vectorized& f32) { +#ifdef __ARM_FEATURE_BF16 + return vcvt_bf16_f32(f32); +#else + static_assert(std::is_same_v); + uint32x4_t as_uint32 = vreinterpretq_u32_f32(f32); + uint32x4_t rounding_bias = vaddq_u32( + vandq_u32(vshrq_n_u32(as_uint32, 16), vdupq_n_u32(1)), + vdupq_n_u32(0x7FFF)); + at_bfloat16x4_t rounded = + vshrn_n_u32(vaddq_u32(as_uint32, rounding_bias), 16); + const auto bf16_nan = vdup_n_u16(0x7FC0); + return vbsl_u16( + vmovn_u32(vreinterpretq_u32_f32(f32.isnan())), bf16_nan, rounded); +#endif // __ARM_FEATURE_BF16 + } + + Vectorized map_with_vec_float_method( + Vectorized (Vectorized::*m)() const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + Vectorized mv0 = (Vectorized(v00).*m)(); + Vectorized mv1 = (Vectorized(v01).*m)(); + at_bfloat16x4_t r00 = convert_bf16_f32(mv0); + at_bfloat16x4_t r01 = convert_bf16_f32(mv1); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + Vectorized map2_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values)); + float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values)); + Vectorized mv0 = (Vectorized(v00).*m)(second_v00); + Vectorized mv1 = (Vectorized(v01).*m)(second_v01); + at_bfloat16x4_t r00 = convert_bf16_f32(mv0); + at_bfloat16x4_t r01 = convert_bf16_f32(mv1); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + Vectorized map2_bitmask_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values)); + float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values)); + Vectorized mv0 = (Vectorized(v00).*m)(second_v00); + Vectorized mv1 = (Vectorized(v01).*m)(second_v01); + // Assume the operator returns a bitmask, not "real" floats, and + // just narrow the bits. All-ones is a NaN and will get mangled by + // conversion! 
+ at_bfloat16x4_t r00 = + at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0))); + at_bfloat16x4_t r01 = + at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1))); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + public: + using Vectorized16::Vectorized16; + + Vectorized() = default; + + Vectorized(c10::BFloat16 val) + : Vectorized16(at_vdupq_n_bf16(c10::bit_cast(val.x))) {} + Vectorized(float val) : Vectorized(c10::BFloat16(val)) {} + Vectorized( + value_type val0, + value_type val1, + value_type val2, + value_type val3, + value_type val4, + value_type val5, + value_type val6, + value_type val7) + : Vectorized16(at_bfloat16x8_t{ + c10::bit_cast(val0.x), + c10::bit_cast(val1.x), + c10::bit_cast(val2.x), + c10::bit_cast(val3.x), + c10::bit_cast(val4.x), + c10::bit_cast(val5.x), + c10::bit_cast(val6.x), + c10::bit_cast(val7.x)}) {} + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // NOTE: blendv has the same problems as it does for Half; see comments in + // vec128_half_neon.h. + Vectorized vec(mask.values); + vec.values = at_vreinterpretq_bf16_u16(vbslq_u16( + at_vreinterpretq_u16_bf16(vec.values), + at_vreinterpretq_u16_bf16(b.values), + at_vreinterpretq_u16_bf16(a.values))); + return vec; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + uint16_t pre_mask[size()] = {0}; + for (int i = 0; i < count; i++) { + pre_mask[i] = 0xFFFF; + } + uint16x8_t mask = vld1q_u16(pre_mask); + + Vectorized vec(at_vreinterpretq_bf16_u16(vbslq_u16( + mask, + at_vreinterpretq_u16_bf16(b.values), + at_vreinterpretq_u16_bf16(a.values)))); + + return vec; + } + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) { + return at_vld1q_bf16(reinterpret_cast(ptr)); + } + __at_align__ at_bfloat16_t tmp_values[size()]; + std::memset(tmp_values, 0, sizeof(tmp_values)); + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(at_bfloat16_t)); + return at_vld1q_bf16(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + at_vst1q_bf16(reinterpret_cast(ptr), values); + return; + } else { + at_bfloat16_t tmp_values[size()]; + at_vst1q_bf16(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(at_bfloat16_t)); + } + } + Vectorized isnan() const { + // NOTE: we could make this faster by doing vectorized checks of + // exponent/payload bits. 
+ __at_align__ c10::BFloat16 tmp[size()]; + __at_align__ c10::BFloat16 res[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i])) { + std::memset(static_cast(&res[i]), 0xFF, sizeof(c10::BFloat16)); + } else { + std::memset(static_cast(&res[i]), 0, sizeof(c10::BFloat16)); + } + } + return loadu(res); + } + bool has_inf_nan() const { + __at_align__ c10::BFloat16 tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } +#define DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(name) \ + Vectorized name() const { \ + return map_with_vec_float_method(&Vectorized::name); \ + } + +#define DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(name) \ + Vectorized name(const Vectorized& other) const { \ + return map2_bitmask_with_vec_float_method( \ + other, &Vectorized::name); \ + } + + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) + Vectorized frac() const; + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) + +#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD +#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; // Vectorized + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + at_bfloat16x8_t x = a; + float32x4_t x1 = + Vectorized::convert_f32_bf16(at_vget_low_bf16(x)); + float32x4_t x2 = + Vectorized::convert_f32_bf16(at_vget_high_bf16(x)); + return {Vectorized(x1), Vectorized(x2)}; +} +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + at_bfloat16x4_t x1 = Vectorized::convert_bf16_f32(a); + at_bfloat16x4_t x2 = Vectorized::convert_bf16_f32(b); + return Vectorized(at_vcombine_bf16(x1, x2)); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return convert_float_bfloat16( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::plus>(), a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::minus>(), a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::multiplies>(), a, b); +} + +template <> 
+Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::divides>(), a, b); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + vandq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + vorrq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + veorq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also, + // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered + // elements, not the bottom and top half, so they don't seem + // particularly useful here. Ideally we would include dot product in + // the Vectorized interface... + return a * b + c; +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + // See NOTE [BF16 FMA] above. 
+ return a * b - c; +} + +#endif // !defined(C10_MOBILE) && defined(__aarch64__) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..bb6e4b697f230164b518e2166ec6dee6b87b5deb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h @@ -0,0 +1,65 @@ +#pragma once +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_half_register_to_float(src[0]); + } +}; +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + const auto [v0, v1] = convert_int8_to_float(src[0]); + return VectorizedN(v0, v1); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + uint16x8_t u16_8 = vld1q_u16(reinterpret_cast(&src[0])); + auto u16_low1 = vget_low_u16(u16_8); + auto u16_high1 = vget_high_u16(u16_8); + float32x4_t f32x4_0 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_low1), 16)); + float32x4_t f32x4_1 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_high1), 16)); + result[0] = f32x4_0; + result[1] = f32x4_1; + return result; + } +}; +// Half register to full register. +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + uint16x4_t u16_8 = vld1_u16(reinterpret_cast(&src[0])); + float32x4_t f32x4_0 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_8), 16)); + result[0] = f32x4_0; + return result; + } +}; + +#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..817b9931bbfb669cb4396034e2612dd1edf24326 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -0,0 +1,631 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#endif + +// Sleef offers vectorized versions of some transcedentals +// such as sin, cos, tan etc.. +// However for now opting for STL, since we are not building +// with Sleef for mobile yet. + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Right now contains only aarch64 implementation. +// Due to follow two reasons aarch32 is not currently supported. +// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics +// that work for aarch64 dont work for aarch32. +// 2. Android NDK r21 has problems with compiling aarch32. +// Clang seg faults. +// https://github.com/android/ndk/issues/1248 +// https://bugs.llvm.org/show_bug.cgi?id=45824 +// Most likely we will do aarch32 support with inline asm. 
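// --- Illustrative sketch (not part of the upstream header) -------------------
// When Sleef is not built in, the USE_SLEEF macro defined just below selects
// its second argument and the transcendental members fall back to applying the
// scalar <cmath> routine to each of the four float lanes via map(); that
// fallback has exactly the shape of this loop. The helper name is made up.
inline void apply_lanewise_sketch(float (&lanes)[4], float (*f)(float)) {
  for (int i = 0; i < 4; ++i) {
    lanes[i] = f(lanes[i]); // e.g. f is the float overload of std::sin
  }
}
// -----------------------------------------------------------------------------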
+#if defined(__aarch64__) + +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." +#endif + +#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res); +}; + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index); + } +}; + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + float32x4_t values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(float32x4_t v) : values(v) {} + Vectorized(float val) : values{vdupq_n_f32(val)} {} + Vectorized(float val0, float val1, float val2, float val3) + : values{val0, val1, val2, val3} {} + Vectorized(float (&arr)[4]) : Vectorized(arr[0], arr[1], arr[2], arr[3]) {} + operator float32x4_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + Vectorized vec; + vec.values = BlendRegs < 0, + (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 1, + (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 2, + (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 3, + (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values); + return vec; + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // TODO + // NB: This requires that each value, i.e., each uint value, + // of the mask either all be zeros or all be 1s. + // We perhaps need some kind of an assert? + // But that will affect performance. 
+ Vectorized vec(mask.values); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const Vectorized step_sizes(0, 1, 2, 3); + return fmadd(step_sizes, step_vec, base_vec); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + case 2: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + case 3: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f32(reinterpret_cast(ptr)); + } else { + __at_align__ float tmp_values[size()]; + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(float)); + return vld1q_f32(reinterpret_cast(tmp_values)); + } + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f32(reinterpret_cast(ptr), values); + } else { + float tmp_values[size()]; + vst1q_f32(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + // Very slow implementation of indexing. + // Only required because vec256_qint refers to this. + // Once we specialize that implementation for ARM + // this should be removed. TODO (kimishpatel) + float operator[](int idx) const { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + float operator[](int idx) { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + // For boolean version where we want to if any 1/all zero + // etc. can be done faster in a different way. 
+ int zero_mask() const { + __at_align__ float tmp[size()]; + store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0.f) { + mask |= (1 << i); + } + } + return mask; + } + Vectorized isnan() const { + return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values))); + } + bool has_inf_nan() const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized map2( + const Vectorized& second, + float (*const f)(float, float)) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_second[size()]; + store(tmp); + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i], tmp_second[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return Vectorized(vabsq_f32(values)); + } + Vectorized angle() const { + auto zero = Vectorized(0); + auto pi = Vectorized(c10::pi); + auto tmp = blendv(zero, pi, *this < zero); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return *this; + } +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, sleef_name) \ + Vectorized name() const { \ + return USE_SLEEF(Vectorized(sleef_name(values)), map(std::name)); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acosh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asin) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asinh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atanh) + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, sleef_name) \ + Vectorized name(const Vectorized& arg) const { \ + return USE_SLEEF( \ + Vectorized(sleef_name(values, arg.values)), \ + map2(arg, std::name)); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(atan2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + copysign, + Sleef_copysignf4) + Vectorized erf() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + erfc, + Sleef_erfcf4_u15) + Vectorized erfinv() const { + return map(calc_erfinv); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) + Vectorized exp_u20() const { + return exp(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + fmod, + Sleef_fmodf4) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + hypot, + Sleef_hypotf4_u05) + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + return map2(x, calc_igamma); + } + Vectorized 
igammac(const Vectorized& x) const { + return map2(x, calc_igammac); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log10) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log1p) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + nextafter, + Sleef_nextafterf4) + Vectorized frac() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sin) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sinh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cosh) + Vectorized ceil() const { + return map(at::native::ceil_impl); + } + Vectorized floor() const { + return map(at::native::floor_impl); + } + Vectorized neg() const { + return Vectorized(vnegq_f32(values)); + } + Vectorized round() const { + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. + return map(at::native::round_impl); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tanh) + Vectorized trunc() const { + return Vectorized(vrndq_f32(values)); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(lgamma) + Vectorized sqrt() const { + return Vectorized(vsqrtq_f32(values)); + } + Vectorized reciprocal() const { + return Vectorized(vdivq_f32(vdupq_n_f32(1.0f), values)); + } + Vectorized rsqrt() const { + return this->sqrt().reciprocal(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(pow) + Vectorized operator==(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vceqq_f32(values, other.values))); + } + + Vectorized operator!=(const Vectorized& other) const { + float32x4_t r0 = + vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, other.values))); + return Vectorized(r0); + } + + Vectorized operator<(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcltq_f32(values, other.values))); + } + + Vectorized operator<=(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcleq_f32(values, other.values))); + } + + Vectorized operator>(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcgtq_f32(values, other.values))); + } + + Vectorized operator>=(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcgeq_f32(values, other.values))); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vaddq_f32(a, b)); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vsubq_f32(a, b)); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vmulq_f32(a, b)); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vdivq_f32(a, b)); +} + +// frac. 
Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vmaxq_f32(a, b)); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vminq_f32(a, b)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + vandq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + vorrq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + veorq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, int32_t* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vfmaq_f32(c, a, b)); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vnegq_f32(vfmsq_f32(c, a, b))); +} + +inline Vectorized Vectorized::erf() const { + // constants + const Vectorized neg_zero_vec(-0.f); + const Vectorized one_vec(1.0f); + const Vectorized p(0.3275911f); + const Vectorized p1(0.254829592f); + const Vectorized 
p2(-0.284496736f); + const Vectorized p3(1.421413741f); + const Vectorized p4(-1.453152027f); + const Vectorized p5(1.061405429f); + // sign(x) + auto sign_mask = neg_zero_vec & *this; + auto abs_vec = this->abs(); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = fmadd(p, abs_vec, one_vec); + auto t = one_vec / tmp0; + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = fmadd(p5, t, p4); + auto tmp2 = fmadd(tmp1, t, p3); + auto tmp3 = fmadd(tmp2, t, p2); + auto r = fmadd(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = (*this) * (*this); + auto neg_pow_2 = pow_2 ^ neg_zero_vec; + auto tmp4 = neg_pow_2.map( + std::exp); // This can be swapped for a faster implementation of exp. + auto tmp5 = tmp4 ^ neg_zero_vec; + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = t * tmp5; + auto tmp7 = fmadd(tmp6, r, one_vec); + return tmp7 ^ sign_mask; +} +#undef DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC +#undef DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC +#endif /* defined(aarch64) */ + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..b90cd0b9bf6bd2a7dfc66c74eb1d95d0c1284be5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -0,0 +1,614 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#include +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Right now contains only aarch64 implementation. +// Due to follow two reasons aarch32 is not currently supported. +// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics +// that work for aarch64 dont work for aarch32. +// 2. Android NDK r21 has problems with compiling aarch32. +// Clang seg faults. +// https://github.com/android/ndk/issues/1248 +// https://bugs.llvm.org/show_bug.cgi?id=45824 +// Most likely we will do aarch32 support with inline asm. +#if !defined(C10_MOBILE) && defined(__aarch64__) + +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." 
+#endif + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res); +}; + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res) { + return vsetq_lane_f16(vgetq_lane_f16(b, index), res, index); + } +}; + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res) { + return vsetq_lane_f16(vgetq_lane_f16(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +// On ARM, Half type supports float16_t->Half constructor and Half->float16_t +// conversion +template <> +class Vectorized : public Vectorized16< + float16x8_t, + c10::Half, + BlendHalfRegs, + Vectorized> { + using Base = Vectorized16< + float16x8_t, + c10::Half, + BlendHalfRegs, + Vectorized>; + friend Base; + + private: + // We use these private map functions to implement various methods + Vectorized map_with_vec_float_method( + Vectorized (Vectorized::*m)() const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + Vectorized mv0 = (Vectorized(v00).*m)(); + Vectorized mv1 = (Vectorized(v01).*m)(); + float16x4_t r00 = vcvt_f16_f32(mv0); + float16x4_t r01 = vcvt_f16_f32(mv1); + return Vectorized(vcombine_f16(r00, r01)); + } + + Vectorized map2_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values)); + float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values)); + Vectorized mv0 = + (Vectorized(v00).*m)(Vectorized(second_v00)); + Vectorized mv1 = + (Vectorized(v01).*m)(Vectorized(second_v01)); + float16x4_t r00 = vcvt_f16_f32(mv0); + float16x4_t r01 = vcvt_f16_f32(mv1); + + // Pack result into Vectorized + return Vectorized(vcombine_f16(r00, r01)); + } + + Vectorized map2_bitmask_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values)); + float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values)); + Vectorized mv0 = + (Vectorized(v00).*m)(Vectorized(second_v00)); + Vectorized mv1 = + (Vectorized(v01).*m)(Vectorized(second_v01)); + // Assume the operator returns a bitmask, not "real" floats, and + // just narrow the bits. All-ones is a NaN and will get mangled by + // conversion! 
+ float16x4_t r00 = + vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0))); + float16x4_t r01 = + vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1))); + + // Pack result into Vectorized + return Vectorized(vcombine_f16(r00, r01)); + } + + public: + using Vectorized16::Vectorized16; + + Vectorized() = default; + + // A ctor that accepts c10::Half is needed to fit interface with vec_base.h + // A second constructor that takes float16_t is also included + Vectorized(c10::Half val) : Vectorized((float16_t)val) {} + Vectorized(float16_t val) : Vectorized16(vdupq_n_f16(val)) {} + Vectorized( + value_type val0, + value_type val1, + value_type val2, + value_type val3, + value_type val4, + value_type val5, + value_type val6, + value_type val7) + : Vectorized16( + float16x8_t{val0, val1, val2, val3, val4, val5, val6, val7}) {} + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // Note: using blendv is very awkward because 0xFFFF is one of + // many NaN's in FP16 It's unfortunate that the mask has type Half + // (required from vec_base) + + // TODO + // NB: This requires that each value, i.e., each uint value, + // of the mask either all be zeros or all be 1s. + // We perhaps need some kind of an assert? + // But that will affect performance. + + // NOTE [vbslq_f16]: vbslq_f16 doesn't work on clang without + // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. vbslq_u16 generates the + // same instruction anyway. see https://godbolt.org/z/cY4a55Y7P + Vectorized vec(mask.values); + vec.values = vreinterpretq_f16_u16(vbslq_u16( + vreinterpretq_u16_f16(vec.values), + vreinterpretq_u16_f16(b.values), + vreinterpretq_u16_f16(a.values))); + return vec; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + uint16_t pre_mask[size()] = {0}; + for (int i = 0; i < count; i++) { + pre_mask[i] = 0xFFFF; + } + uint16x8_t mask = vld1q_u16(pre_mask); + + // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16 + // so we directly use vbslq_u16 instead. (See NOTE [vbslq_f16] above.) + Vectorized vec(vreinterpretq_f16_u16(vbslq_u16( + mask, + vreinterpretq_u16_f16(b.values), + vreinterpretq_u16_f16(a.values)))); + + return vec; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f16(reinterpret_cast(ptr)); + } + __at_align__ float16_t tmp_values[size()]; + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(float16_t)); + return vld1q_f16(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f16(reinterpret_cast(ptr), values); + return; + } else { + float16_t tmp_values[size()]; + vst1q_f16(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float16_t)); + } + } + // For boolean version where we want to if any 1/all zero + // etc. can be done faster in a different way. + Vectorized isnan() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values))); +#else + // NOTE: we could make this faster by doing vectorized checks of + // exponent/payload bits. 
+ __at_align__ c10::Half tmp[size()]; + __at_align__ c10::Half res[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i])) { + std::memset(static_cast(&res[i]), 0xFF, sizeof(c10::Half)); + } else { + std::memset(static_cast(&res[i]), 0, sizeof(c10::Half)); + } + } + return loadu(res); +#endif + } + bool has_inf_nan() const { + __at_align__ c10::Half tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } + Vectorized abs() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vabsq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::abs); +#endif + } + Vectorized frac() const; + Vectorized neg() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vnegq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::neg); +#endif + } + Vectorized trunc() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vrndq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::trunc); +#endif + } + Vectorized sqrt() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vsqrtq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::sqrt); +#endif + } + Vectorized reciprocal() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + auto ones = vdupq_n_f16(1.0f); + return Vectorized(vdivq_f16(ones, values)); +#else + return map_with_vec_float_method(&Vectorized::reciprocal); +#endif + } + Vectorized operator==(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vceqq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator==); +#endif + } + + Vectorized operator!=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, other.values)))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator!=); +#endif + } + + Vectorized operator<(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcltq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator<); +#endif + } + + Vectorized operator<=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcleq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator<=); +#endif + } + + Vectorized operator>(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcgtq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator>); +#endif + } + + Vectorized operator>=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcgeq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator>=); +#endif + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; // 
Vectorized + +inline std::tuple, Vectorized> convert_half_float( + const Vectorized& a) { + static_assert(Vectorized::size() == 2 * Vectorized::size()); + float16x8_t x = a; + float32x4_t x1 = vcvt_f32_f16(vget_low_f16(x)); + float32x4_t x2 = vcvt_f32_f16(vget_high_f16(x)); + return {Vectorized(x1), Vectorized(x2)}; +} +inline Vectorized convert_float_half( + const Vectorized& a, + const Vectorized& b) { + static_assert(Vectorized::size() == 2 * Vectorized::size()); + float32x4_t x = a; + float32x4_t y = b; + float16x4_t x1 = vcvt_f16_f32(x); + float16x4_t x2 = vcvt_f16_f32(y); + return Vectorized(vcombine_f16(x1, x2)); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_half_float(a); + const auto [b_float_low, b_float_high] = convert_half_float(b); + return convert_float_half( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vaddq_f16(a, b)); +#else + return binary_operator_via_float(std::plus>(), a, b); +#endif +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vsubq_f16(a, b)); +#else + return binary_operator_via_float(std::minus>(), a, b); +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vmulq_f16(a, b)); +#else + return binary_operator_via_float(std::multiplies>(), a, b); +#endif +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vdivq_f16(a, b)); +#else + return binary_operator_via_float(std::divides>(), a, b); +#endif +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vmaxq_f16(a, b)); +#else + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +#endif +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vminq_f16(a, b)); +#else + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +#endif +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +// These are global functions, so the defaults in vec_base.h should +// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available. 
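// --- Illustrative sketch (not part of the upstream header) -------------------
// The generic convert() in vec_base.h is effectively the scalar cast loop
// below; the specializations that follow replace it with vcvtq_s16_f16 /
// vcvtq_f16_s16 when the FP16 vector arithmetic extension is available. The
// helper name is made up for illustration.
template <typename SrcT, typename DstT>
inline void convert_scalar_fallback_sketch(const SrcT* src, DstT* dst, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    dst[i] = static_cast<DstT>(src[i]); // one lane at a time, no NEON required
  }
}
// -----------------------------------------------------------------------------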
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline void convert(const float16_t* src, int16_t* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int16_t* src, float16_t* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vfmaq_f16(c, a, b)); +#else + return a * b + c; +#endif +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vnegq_f16(vfmsq_f16(c, a, b))); +#else + return a * b - c; +#endif +} +#endif // !defined(C10_MOBILE) && defined(__aarch64__) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..a5ba8429f27a9b9efee41361234508e31754ba61 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h @@ -0,0 +1,307 @@ +#pragma once +// Shared code for bfloat16 and float16. + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +// Shared implementation between Vectorized and +// Vectorized. Uses CRTP to allow derived class +// customization. 
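// --- Illustrative sketch (not part of the upstream header) -------------------
// The CRTP pattern used by Vectorized16 below, reduced to its essentials: the
// base class is parameterized on the derived type and reaches derived-specific
// behaviour through a static_cast, so the Half and BFloat16 vectors can share
// one body while keeping their own load/store/conversion code. Names here are
// made up for illustration.
template <typename Derived>
struct CrtpBaseSketch {
  Derived squared() const {
    const Derived& self = static_cast<const Derived&>(*this); // the "curiously recurring" downcast
    return self * self; // resolves against Derived's own operator*
  }
};
// The real shared base, Vectorized16, follows.
// -----------------------------------------------------------------------------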
+template < + typename VecT, + typename ValueT, + template typename BlendRegs, + typename Derived> +struct Vectorized16 { + protected: + VecT values; + + public: + using value_type = ValueT; + using size_type = int; + static constexpr size_type size() { + static_assert(sizeof(VecT) == 8 * sizeof(value_type)); + return 8; + } + + protected: + Derived map2( + const Derived& second, + value_type (*const f)(value_type, value_type)) const { + __at_align__ value_type tmp_first[size()]; + __at_align__ value_type tmp_second[size()]; + static_cast(this)->store( + tmp_first); // store this to tmp_first + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp_first[i] = f(tmp_first[i], tmp_second[i]); + } + return Derived::loadu(tmp_first); + } + + public: + Vectorized16() = default; + Vectorized16(VecT v) : values(v) {} + + operator VecT() const { + return values; + } + + template + static Derived blend(const Derived& a, const Derived& b) { + Derived vec; + vec.values = BlendRegs < 0, + (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 1, + (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 2, + (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 3, + (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values); + + vec.values = BlendRegs < 4, + (mask & 0x10) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 5, + (mask & 0x20) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 6, + (mask & 0x40) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 7, + (mask & 0x80) != 0 > ::impl(a.values, b.values, vec.values); + + return vec; + } + + template + static Derived arange( + value_type base = 0, + step_t step = static_cast(1)) { + const Derived base_vec(base); + const Derived step_vec(step); + const Derived step_sizes( + value_type(0), + value_type(1), + value_type(2), + value_type(3), + value_type(4), + value_type(5), + value_type(6), + value_type(7)); + return fmadd(step_sizes, step_vec, base_vec); + } + + // Very slow implementation of indexing. + // Only required because vec256_qint refers to this. + // Once we specialize that implementation for ARM + // this should be removed. TODO (kimishpatel) + value_type operator[](int idx) const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + return tmp[idx]; + } + + int zero_mask() const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0) { + mask |= (1 << i); + } + } + return mask; + } + + Derived map(value_type (*const f)(value_type)) const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return Derived::loadu(tmp); + } + + Derived angle() const { + auto zero = Derived(0); + auto pi = Derived(c10::pi); + auto tmp = + Derived::blendv(zero, pi, *static_cast(this) < zero); + return Derived::blendv( + tmp, + *static_cast(this), + static_cast(this)->isnan()); + } + Derived real() const { + return *this; + } + Derived imag() const { + return Derived(0); + } + Derived conj() const { + return *this; + } + + // Sleef does not support FP16/BF16, so many math functions are applied by + // converting to FP32, applying the math function, and then converting back to + // FP16/BF16. 
+ Derived acos() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::acos); + } + Derived acosh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::acosh); + } + Derived asin() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::asin); + } + Derived asinh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::asinh); + } + Derived atan() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::atan); + } + Derived atanh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::atanh); + } + Derived atan2(const Derived& exp) const { + return static_cast(this)->map2_with_vec_float_method( + exp, &Vectorized::atan2); + } + Derived copysign(const Derived& sign) const { + return static_cast(this)->map2_with_vec_float_method( + sign, &Vectorized::copysign); + } + Derived erf() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erf); + } + Derived erfc() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erfc); + } + Derived erfinv() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erfinv); + } + Derived exp() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp); + } + Derived exp2() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp2); + } + Derived expm1() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::expm1); + } + Derived exp_u20() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp_u20); + } + Derived fmod(const Derived& q) const { + // This function is questionable with a conversion, so we use map2 + return map2(q, std::fmod); + } + Derived hypot(const Derived& b) const { + return static_cast(this)->map2_with_vec_float_method( + b, &Vectorized::hypot); + } + Derived i0() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::i0); + } + Derived i0e() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::i0e); + } + Derived digamma() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::digamma); + } + Derived igamma(const Derived& x) const { + return static_cast(this)->map2_with_vec_float_method( + x, &Vectorized::igamma); + } + Derived igammac(const Derived& x) const { + return static_cast(this)->map2_with_vec_float_method( + x, &Vectorized::igammac); + } + Derived log() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log); + } + Derived log10() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log10); + } + Derived log1p() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log1p); + } + Derived log2() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log2); + } + Derived nextafter(const Derived& b) const { + // This function does not make sense with conversion, so we use map2 + return map2(b, std::nextafter); + } + Derived sin() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::sin); + } + Derived sinh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::sinh); + } + Derived cos() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::cos); + } + Derived cosh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::cosh); + } + Derived ceil() const { + // This 
function is questionable with a conversion, so we use map + return map(at::native::ceil_impl); + } + Derived floor() const { + // This function is questionable with a conversion, so we use map + return map(at::native::floor_impl); + } + Derived round() const { + // This function is questionable with a conversion, so we use map + return map(at::native::round_impl); + } + Derived tan() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::tan); + } + Derived tanh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::tanh); + } + Derived lgamma() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::lgamma); + } + Derived rsqrt() const { + return static_cast(this)->sqrt().reciprocal(); + } + Derived pow(const Derived& exp) const { + return static_cast(this)->map2_with_vec_float_method( + exp, &Vectorized::pow); + } +}; + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..86b0af76314a0334e67f9ea646059d8a9393076a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h @@ -0,0 +1,396 @@ +/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */ + +__extension__ extern __inline uint8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u8_x2(const uint8_t* __a) { + uint8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s8_x2(const int8_t* __a) { + int8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u16_x2(const uint16_t* __a) { + uint16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s16_x2(const int16_t* __a) { + int16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u32_x2(const uint32_t* __a) { + uint32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s32_x2(const int32_t* __a) { + int32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u64_x2(const uint64_t* __a) { + uint64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s64_x2(const int64_t* __a) { + int64x1x2_t ret; + __builtin_aarch64_simd_oi __o; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline 
float16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f16_x2(const float16_t* __a) { + float16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f32_x2(const float32_t* __a) { + float32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f64_x2(const float64_t* __a) { + float64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p8_x2(const poly8_t* __a) { + poly8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p16_x2(const poly16_t* __a) { + poly16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p64_x2(const poly64_t* __a) { + poly64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u8_x2(const uint8_t* __a) { + uint8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s8_x2(const int8_t* __a) { + int8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u16_x2(const uint16_t* __a) { + uint16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s16_x2(const int16_t* __a) { + int16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u32_x2(const uint32_t* __a) { + uint32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s32_x2(const int32_t* __a) { + int32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u64_x2(const uint64_t* __a) { + uint64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s64_x2(const int64_t* __a) { + int64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : 
"Q"(*__a)); + return ret; +} + +__extension__ extern __inline float16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f16_x2(const float16_t* __a) { + float16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f32_x2(const float32_t* __a) { + float32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f64_x2(const float64_t* __a) { + float64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p8_x2(const poly8_t* __a) { + poly8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p16_x2(const poly16_t* __a) { + poly16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p64_x2(const poly64_t* __a) { + poly64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +/* vst1x2 */ + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s64_x2(int64_t* __a, int64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f64_x2(float64_t* __a, float64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s8_x2(int8_t* __a, int8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s16_x2(int16_t* __a, int16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s32_x2(int32_t* __a, int32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u8_x2(uint8_t* __a, 
uint8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f16_x2(float16_t* __a, float16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f32_x2(float32_t* __a, float32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s8_x2(int8_t* __a, int8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s16_x2(int16_t* __a, int16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s32_x2(int32_t* __a, int32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s64_x2(int64_t* __a, int64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void 
+ __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f16_x2(float16_t* __a, float16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f64_x2(float64_t* __a, float64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..a90690d8533f05547b307aaf200294dd75c3850b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h @@ -0,0 +1,7 @@ +/* Workaround for missing vst1q_f32_x2 in gcc-8. */ + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h new file mode 100644 index 0000000000000000000000000000000000000000..aebb15d91d01dde8b676688f62f2dab3d85bc5cf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h @@ -0,0 +1,430 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +#include +#if !( \ + defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || \ + defined(CPU_CAPABILITY_ZVECTOR)) +#if defined(CPU_CAPABILITY_SVE256) +#include +#else +// clang-format off +#include +#include +#include +#include +#endif +#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16) +#include +#endif +#include +#include +#include +// clang-format on +#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) +#include +#else +// clang-format off +#include +#include +#include +// clang-format on +#endif + +#include +#include + +#include +#include +#include +#include +#include + +namespace at::vec { + +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. 
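+// Illustration of the mechanism (hedged example): if CPU_CAPABILITY expands
+// to, say, AVX2 in one translation unit and DEFAULT in another, the same
+// source mangles as at::vec::AVX2::Vectorized<float> versus
+// at::vec::DEFAULT::Vectorized<float>, so the per-ISA objects cannot collide
+// at link time, while callers may still write at::vec::Vectorized<float>
+// because the namespace is declared `inline`.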
+inline namespace CPU_CAPABILITY { + +inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; +} + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << "]"; + return stream; +} + +#if defined(CPU_CAPABILITY_AVX2) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castpd_ps(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castps_pd(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castsi256_ps(src); +} + +template <> +inline Vectorized cast( + const Vectorized& src) { + return _mm256_castsi256_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm256_i64gather_pd(base_addr, vindex, scale); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm256_i32gather_ps(base_addr, vindex, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. 
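+// Element-wise sketch of the masked gathers below, assuming the usual AVX2
+// mask_i32gather/mask_i64gather semantics (scale is in bytes, the sign bit
+// of each mask lane selects the source):
+//   result[i] = sign_bit(mask[i])
+//       ? *(const T*)((const char*)base_addr + vindex[i] * scale)
+//       : src[i];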
+template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000)); + return _mm256_sub_epi64( + _mm256_castpd_si256(x), + _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm256_cvttps_epi32(src); +} + +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + __m256i magic_i_lo = _mm256_set1_epi64x(0x4330000000000000); /* 2^52 */ + __m256i magic_i_hi32 = + _mm256_set1_epi64x(0x4530000080000000); /* 2^84 + 2^63 */ + __m256i magic_i_all = + _mm256_set1_epi64x(0x4530000080100000); /* 2^84 + 2^63 + 2^52 */ + __m256d magic_d_all = _mm256_castsi256_pd(magic_i_all); + + __m256i v_lo = _mm256_blend_epi32( + magic_i_lo, src, 0b01010101); /* v_low = low32 + 2^52 */ + __m256i v_hi = _mm256_srli_epi64(src, 32); + v_hi = _mm256_xor_si256( + v_hi, magic_i_hi32); /* v_hi = high32*2^32 + 2^84 + 2^63 */ + /* int64 = low32 + high32*2^32 = v_hi + v_lo - 2^52 - 2^63 - 2^84 */ + __m256d v_hi_dbl = _mm256_sub_pd(_mm256_castsi256_pd(v_hi), magic_d_all); + __m256d result = _mm256_add_pd(v_hi_dbl, _mm256_castsi256_pd(v_lo)); + return result; +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm256_cvtepi32_ps(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + // swap lanes: + // a_swapped = {a0, a1, b0, b1} + // b_swapped = {a2, a3, b2, b3} + auto a_swapped = + _mm256_permute2f128_pd(a, b, 0b0100000); // 0, 2. 4 bits apart + auto b_swapped = + _mm256_permute2f128_pd(a, b, 0b0110001); // 1, 3. 4 bits apart + + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + _mm256_permute4x64_pd(a_swapped, 0b11011000), // 0, 2, 1, 3 + _mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3 +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + + // swap lanes: + // a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3} + // b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7} + // TODO: can we support caching this? + auto a_swapped = + _mm256_permute2f128_ps(a, b, 0b0100000); // 0, 2. 4 bits apart + auto b_swapped = + _mm256_permute2f128_ps(a, b, 0b0110001); // 1, 3. 
4 bits apart + + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + return std::make_pair( + _mm256_permutevar8x32_ps(a_swapped, group_ctrl), + _mm256_permutevar8x32_ps(b_swapped, group_ctrl)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + + // group cols crossing lanes: + // a_grouped = {a0, a1, b0, b1} + // b_grouped = {a2, a3, b2, b3} + auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000); // 0, 2, 1, 3 + auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000); // 0, 2, 1, 3 + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + _mm256_permute2f128_pd( + a_grouped, b_grouped, 0b0100000), // 0, 2. 4 bits apart + _mm256_permute2f128_pd( + a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + + // group cols crossing lanes: + // a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3} + // b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7} + // TODO: can we support caching this? + const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl); + auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl); + + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + _mm256_permute2f128_ps( + a_grouped, b_grouped, 0b0100000), // 0, 2. 4 bits apart + _mm256_permute2f128_ps( + a_grouped, b_grouped, 0b0110001)); // 1, 3. 
4 bits apart +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_ps(v, mask_float); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_epi32(v, mask_int32); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask = _mm256_set_epi8( + 1, + 0, + 3, + 2, + 5, + 4, + 7, + 6, + 9, + 8, + 11, + 10, + 13, + 12, + 15, + 14, + 1, + 0, + 3, + 2, + 5, + 4, + 7, + 6, + 9, + 8, + 11, + 10, + 13, + 12, + 15, + 14); + auto reversed = _mm256_shuffle_epi8(v, mask); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +inline __m256i flip8(const __m256i& v) { + const __m256i mask_int8 = _mm256_set_epi8( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15); + auto reversed = _mm256_shuffle_epi8(v, mask_int8); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m256i* self_ = reinterpret_cast(self.as_bytes()); + const __m256i* other_ = reinterpret_cast(other.as_bytes()); + __m256i out = _mm256_and_si256(*self_, *other_); + Vectorized ret; + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + +#endif // (defined(CPU_CAPABILITY_AVX2) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h new file mode 100644 index 0000000000000000000000000000000000000000..0e4f5c2e9886c609906afc40520d280b43340c4c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h @@ -0,0 +1,829 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +// Used for shared functions and classes for vec256_bfloat16.h and +// vec256_half.h. Any functions/classes that are common between those two files +// should be defined here. Any non-shared functions/classes should be defined in +// the respective files. 
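+// Scalar sketch of the float -> bfloat16 rounding performed by cvtfp32_bf16
+// below (round half to even on the truncated bits, NaN forced to an all-ones
+// payload); illustration only, with `bits` standing for the float's bit
+// pattern and `x` for its value:
+//   uint32_t lsb = (bits >> 16) & 1;
+//   uint16_t bf  = std::isnan(x) ? 0xffff
+//                                : uint16_t((bits + 0x7fff + lsb) >> 16);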
+ +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +#ifndef SLEEF_CONST +#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER) +#define SLEEF_CONST const +#else +#define SLEEF_CONST +#endif +#define SLEEF_CONST_OLD SLEEF_CONST +#else +#define SLEEF_CONST_OLD +#endif + +// bfloat16 conversion +static inline void cvtbf16_fp32(const __m128i& a, __m256& o) { + o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) { + __m128i lo = _mm256_extractf128_si256(a, 0); + __m128i hi = _mm256_extractf128_si256(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m128i cvtfp32_bf16(const __m256& src) { + __m256i value = _mm256_castps_si256(src); + __m256i nan = _mm256_set1_epi32(0xffff); + __m256i mask = _mm256_castps_si256(_mm256_cmp_ps(src, src, _CMP_ORD_Q)); + __m256i ones = _mm256_set1_epi32(0x1); + __m256i vec_bias = _mm256_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm256_and_si256(_mm256_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm256_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm256_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm256_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm256_blendv_epi8(nan, t_value, mask); + t_value = + _mm256_packus_epi32(t_value, t_value); // t[4-7] t[4-7] t[0-4] t[0-4] + t_value = _mm256_permute4x64_epi64(t_value, 0xd8); // 11 01 10 00 + return _mm256_castsi256_si128(t_value); +} + +static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) { + __m256i lo = _mm256_castps_si256(a); + __m256i hi = _mm256_castps_si256(b); + __m256i nan = _mm256_set1_epi32(0xffff); + __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q)); + __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q)); + __m256i ones = _mm256_set1_epi32(0x1); + __m256i vec_bias = _mm256_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm256_and_si256(_mm256_srli_epi32(lo, 16), ones); + auto t_hi = _mm256_and_si256(_mm256_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm256_add_epi32(t_lo, vec_bias); + t_hi = _mm256_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm256_add_epi32(t_lo, lo); + t_hi = _mm256_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm256_srli_epi32(t_lo, 16); + t_hi = _mm256_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm256_blendv_epi8(nan, t_lo, mask_lo); + t_hi = _mm256_blendv_epi8(nan, t_hi, mask_hi); + + t_lo = _mm256_packus_epi32( + t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + return _mm256_permute4x64_epi64(t_lo, 0xd8); // 11 01 10 00 +} + +static inline __m256i merge_compare_result(const __m256& a, const __m256& b) { + __m256i lo = _mm256_castps_si256(a); + __m256i hi = _mm256_castps_si256(b); + lo = _mm256_srli_epi32(lo, 16); + hi = _mm256_srli_epi32(hi, 16); + auto out = _mm256_packus_epi32(lo, hi); + return _mm256_permute4x64_epi64(out, 0xd8); +} + +// float16 conversion +static inline void cvtfp16_fp32(const __m128i& a, __m256& o) { + o = _mm256_cvtph_ps(a); +} + +static inline void cvtfp16_fp32(const __m256i& a, 
__m256& o1, __m256& o2) { + __m128i lo = _mm256_extractf128_si256(a, 0); + __m128i hi = _mm256_extractf128_si256(a, 1); + cvtfp16_fp32(lo, o1); + cvtfp16_fp32(hi, o2); +} + +static inline __m128i cvtfp32_fp16(const __m256& src) { + return _mm256_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +static inline __m256i cvtfp32_fp16(const __m256& a, const __m256& b) { + __m128i lo = + _mm256_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i hi = + _mm256_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1); +} + +// dtype conversion between float16/bfloat16 and float32 +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m128i& a, __m256& o); +template <> +inline void cvt_to_fp32(const __m128i& a, __m256& o) { + cvtbf16_fp32(a, o); +} +template <> +inline void cvt_to_fp32(const __m128i& a, __m256& o) { + cvtfp16_fp32(a, o); +} + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2); +template <> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2) { + cvtbf16_fp32(a, o1, o2); +} +template <> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2) { + cvtfp16_fp32(a, o1, o2); +} + +template < + typename T, + bool is_compare_op = false, + typename std::enable_if_t, int> = 0> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b); +template <> +inline __m256i cvt_from_fp32( + const __m256& a, + const __m256& b) { + return cvtfp32_bf16(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return merge_compare_result(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return cvtfp32_fp16(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return cvtfp32_fp16(a, b); +} + +template +class Vectorized16 { + static_assert( + is_reduced_floating_point_v, + "Support only float16 and bfloat16."); + + protected: + __m256i values; + + public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized16() {} + Vectorized16(__m256i v) : values(v) {} + Vectorized16(T val) { + value_type uw = val.x; + values = _mm256_set1_epi16(uw); + } + Vectorized16( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16) { + values = _mm256_setr_epi16( + val1.x, + val2.x, + val3.x, + val4.x, + val5.x, + val6.x, + val7.x, + val8.x, + val9.x, + val10.x, + val11.x, + val12.x, + val13.x, + val14.x, + val15.x, + val16.x); + } + operator __m256i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256i cmp = _mm256_cmpeq_epi16(values, _mm256_set1_epi16(0)); + return _mm256_movemask_epi8(cmp); + } + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) + return _mm256_loadu_si256(reinterpret_cast(ptr)); + + __at_align__ int16_t tmp_values[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (const auto i : c10::irange(count, size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return 
_mm256_loadu_si256(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi16(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi16(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi16(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi16(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi16(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi16(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi16(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi16(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi16(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi16(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi16(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi16(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi16(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi16(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi16(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi16(b.values, 15); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + T base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + + // 'const' type qualifier on return type has no effect, but sleef defines this + // this way For example `Sleef_exp2f8_u10` signature is `const __m256 + // (__m256)` + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wignored-qualifiers") + Vectorized map(SLEEF_CONST __m256 (*SLEEF_CONST_OLD vop)(__m256)) const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvt_from_fp32(o1, o2); + } + C10_DIAGNOSTIC_POP() + Vectorized isnan() 
const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + lo = _mm256_cmp_ps(lo, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + hi = _mm256_cmp_ps(hi, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + return merge_compare_result(lo, hi); + } + Vectorized abs() const { + return _mm256_andnot_si256(_mm256_set1_epi16(0x8000), values); + } + Vectorized angle() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto angle_lambda = [](__m256 values_2) { + const auto zero_vec = _mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values_2, values_2, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(c10::pi); + + const auto neg_mask = _mm256_cmp_ps(values_2, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf8_u10); + } + Vectorized acosh() const { + return map(Sleef_acoshf8_u10); + } + Vectorized asin() const { + return map(Sleef_asinf8_u10); + } + Vectorized atan() const { + return map(Sleef_atanf8_u10); + } + Vectorized atanh() const { + return map(Sleef_atanhf8_u10); + } + Vectorized atan2(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f8_u10(lo, b1); + auto o2 = Sleef_atan2f8_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized copysign(const Vectorized& sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m256i mask_value = _mm256_set1_epi32(~0x80008000); + __m256i mask_signbit = _mm256_set1_epi32(0x80008000); + return Vectorized(_mm256_or_si256( + _mm256_and_si256(values, mask_value), + _mm256_and_si256(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff8_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf8_u15); + } + Vectorized erfinv() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf8_u10); + } + Vectorized exp2() const { + return map(Sleef_exp2f8_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f8_u10); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + __m256 x_lo, x_hi; + cvt_to_fp32(values, x_lo, x_hi); + __m256 q_lo, q_hi; + cvt_to_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf8(x_lo, q_lo); + auto o2 = Sleef_fmodf8(x_hi, q_hi); + return cvt_from_fp32(o1, o2); + } + Vectorized hypot(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf8_u05(lo, b1); + auto o2 = Sleef_hypotf8_u05(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + 
__at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0e() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm256_loadu_ps(tmp1); + const auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized digamma() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_digamma(tmp1[i]); + tmp2[i] = calc_digamma(tmp2[i]); + } + const auto o1 = _mm256_loadu_ps(tmp1); + const auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized igamma(const Vectorized& x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + + Vectorized igammac(const Vectorized& x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf8_u10); + } + Vectorized log2() const { + return map(Sleef_log2f8_u10); + } + Vectorized log10() const { + return map(Sleef_log10f8_u10); + } + Vectorized log1p() const { + return map(Sleef_log1pf8_u10); + } + Vectorized sin() const { + return map(Sleef_sinf8_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf8_u10); + } + Vectorized cos() const { + return map(Sleef_cosf8_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf8_u10); + } + Vectorized ceil() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_ceil_ps(lo); + auto o2 = _mm256_ceil_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized floor() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = 
_mm256_floor_ps(lo); + auto o2 = _mm256_floor_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized neg() const { + return _mm256_xor_si256(values, _mm256_set1_epi16(0x8000)); + } + Vectorized round() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = + _mm256_round_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = + _mm256_round_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf8_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf8_u10); + } + Vectorized trunc() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf8_u10); + } + Vectorized sqrt() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_sqrt_ps(lo); + auto o2 = _mm256_sqrt_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized reciprocal() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm256_set1_ps(1); + auto o1 = _mm256_div_ps(ones, lo); + auto o2 = _mm256_div_ps(ones, hi); + return cvt_from_fp32(o1, o2); + } + Vectorized rsqrt() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm256_set1_ps(1); + auto o1 = _mm256_div_ps(ones, _mm256_sqrt_ps(lo)); + auto o2 = _mm256_div_ps(ones, _mm256_sqrt_ps(hi)); + return cvt_from_fp32(o1, o2); + } + Vectorized pow(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_powf8_u10(lo, b1); + auto o2 = Sleef_powf8_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + + private: + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvt_to_fp32(values, a_lo, a_hi); + cvt_to_fp32(b.values, b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); + } + + public: + Vectorized inline operator>(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_GT_OQ); + }); + } + Vectorized inline operator<(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_LT_OQ); + }); + } + Vectorized inline operator>=(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_GE_OQ); + }); + } + Vectorized inline operator<=(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_LE_OQ); + }); + } + Vectorized inline operator==(const Vectorized16& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_EQ_OQ); + }); + } + Vectorized inline operator!=(const Vectorized16& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ); + }); + } +}; + +template +static inline Vectorized binary_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvt_to_fp32(__m256i(a), a_lo, a_hi); + cvt_to_fp32(__m256i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); +} + +#define CONVERT_VECTORIZED_INIT(type, 
name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + __m256 o1, o2; \ + cvt_to_fp32(__m256i(a), o1, o2); \ + return std::make_tuple(o1, o2); \ + } \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + return cvt_from_fp32(__m256(a), __m256(b)); \ + } + +#define LOAD_FP32_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + auto values = _mm_loadu_si128(reinterpret_cast(data)); \ + __m256 out_values; \ + cvt_to_fp32(values, out_values); \ + out = out_values; \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + auto vec = Vectorized::loadu(data); \ + __m256 out1_values, out2_values; \ + cvt_to_fp32(vec, out1_values, out2_values); \ + out1 = out1_values; \ + out2 = out2_values; \ + } + +#else // CPU_CAPABILITY_AVX2 + +#define CONVERT_NON_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr2); \ + convert(arr2, arr, K); \ + return std::make_tuple( \ + Vectorized::loadu(arr), \ + Vectorized::loadu(arr + Vectorized::size())); \ + } \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr); \ + b.store(arr + Vectorized::size()); \ + convert(arr, arr2, K); \ + return Vectorized::loadu(arr2); \ + } + +#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + __at_align__ float values[Vectorized::size()]; \ + for (const auto k : c10::irange(Vectorized::size())) { \ + values[k] = data[k]; \ + } \ + out = Vectorized::loadu(values); \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + load_fp32_from_##name(data, out1); \ + data += Vectorized::size(); \ + load_fp32_from_##name(data, out2); \ + } + +#endif // CPU_CAPABILITY_AVX2 +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..1a66bd197eb025fe402221f400ff18a08e24561c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -0,0 +1,280 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = BFloat16; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_si256(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_si256(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_si256(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + auto max_lo = _mm256_max_ps(a_lo, b_lo); + auto max_hi = _mm256_max_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(max_lo, nan_lo); + auto o2 = _mm256_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
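+// Same trick as in maximum above: _mm256_min_ps on its own would simply
+// return its second operand for NaN lanes, so an unordered-compare mask
+// (all ones for any NaN lane) is OR-ed into the result; an all-ones
+// single-precision bit pattern is itself a NaN, which gives the propagating
+// behaviour required here.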
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + auto min_lo = _mm256_min_ps(a_lo, b_lo); + auto min_hi = _mm256_min_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(min_lo, nan_lo); + auto o2 = _mm256_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + __m256 max_lo, max_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(min), min_lo, min_hi); + cvtbf16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo)); + auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 max_lo, max_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, a_lo); + auto o2 = _mm256_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(min), min_lo, min_hi); + auto o1 = _mm256_max_ps(min_lo, a_lo); + auto o2 = _mm256_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, BFloat16* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = _mm256_loadu_ps(&src[i]); + __m256 b = _mm256_loadu_ps(&src[i + 8]); + + __m256i bf = cvtfp32_bf16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, BFloat16* dst, int64_t n) { + auto load_float = [](const double* src) -> __m256 { + // Load one float vector from an array of doubles + __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src)); + __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4)); + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = load_float(&src[i]); + __m256 b = load_float(&src[i + 8]); + + __m256i bf = cvtfp32_bf16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + __m256 c_lo, c_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + 
cvtbf16_fp32(__m256i(b), b_lo, b_hi); + cvtbf16_fp32(__m256i(c), c_lo, c_hi); + auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +CONVERT_VECTORIZED_INIT(BFloat16, bfloat16) +LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16) + +#else // defined(CPU_CAPABILITY_AVX2) + +#if !( \ + defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE256)) +CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) +#endif + +LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16) +#endif // defined(CPU_CAPABILITY_AVX2) +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h new file mode 100644 index 0000000000000000000000000000000000000000..11dcb6b53e337b9b18de58c5a6cf99d25359efe6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h @@ -0,0 +1,536 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m256d values; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 2; + } + Vectorized() {} + Vectorized(__m256d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm256_setr_pd(real_value, imag_value, real_value, imag_value); + } + Vectorized(c10::complex val1, c10::complex val2) { + values = _mm256_setr_pd(val1.real(), val1.imag(), val2.real(), val2.imag()); + } + operator __m256d() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 4, "Unexpected mask value"); + switch (mask) { + case 0: + return a; + case 1: + return _mm256_blend_pd(a.values, b.values, 0x03); + case 2: + return _mm256_blend_pd(a.values, b.values, 0x0c); + case 3: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values); + return _mm256_blendv_pd(a.values, b.values, mask_); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>(base, base + step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm256_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
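+    // (Added note: lanes at index >= count are left at 0.0 by the loop below,
+    // so the full-width _mm256_load_pd afterwards never reads indeterminate
+    // bytes.)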
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm256_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2 * size()]; + _mm256_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256d abs_2_() const { + auto val_2 = _mm256_mul_pd(values, values); // a*a b*b + return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m256d abs_() const { + auto real = _mm256_movedup_pd(values); // real real + // movehdup_pd does not exist... + auto imag = _mm256_permute_pd(values, 0xf); // imag imag + return Sleef_hypotd4_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm256_and_pd(abs_(), real_mask); // abs 0 + } + __m256d angle_() const { + // angle = atan2(b/a) + auto b_a = _mm256_permute_pd(values, 0x05); // b a + return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle + return _mm256_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm256_setzero_pd(); + auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ); + auto div = _mm256_div_pd(values, abs); + return _mm256_blendv_pd(div, zero, mask); + } + __m256d real_() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm256_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m256d imag_() const { + const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF)); + return _mm256_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm256_permute_pd(imag_(), 0x05); // b a + } + __m256d conj_() const { + const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + return _mm256_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. 
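+    // (Added note: map() evaluates std::log on each c10::complex<double> lane
+    // in turn, i.e. log(a+bi) = log|z| + i*atan2(b, a), keeping the scalar
+    // library's semantics for special values.)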
+ return map(std::log); + } + Vectorized> log2() const { + const __m256d log2_ = _mm256_set1_pd(std::log(2)); + return _mm256_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m256d log10_ = _mm256_set1_pd(std::log(10)); + return _mm256_div_pd(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m256d one = _mm256_set1_pd(1); + + // auto conj = conj_(); + // auto b_a = _mm256_permute_pd(conj, 0x05); //-b a + // auto ab = _mm256_mul_pd(conj, b_a); //-ab + // -ab auto im = _mm256_add_pd(ab, ab); //-2ab -2ab + + // auto val_2 = _mm256_mul_pd(values, values); // a*a + // b*b auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05)); // + // a*a-b*b b*b-a*a re = _mm256_sub_pd(one, re); + + // auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt(); //sqrt(re + + // i*im) auto ln = Vectorized(_mm256_add_pd(b_a, root)).log(); //ln(iz + + // sqrt()) return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj(); + // //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0); + return _mm256_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expd4_u10(values); //exp(a) exp(b) exp = + // _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosd4_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, + // 0x05), + // sin_cos.x, 0x0A); //cos(b) sin(b) + // return _mm256_mul_pd(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m256d ln_2 = _mm256_set1_pd(c10::ln_2); + Vectorized> scaled_values = + _mm256_mul_pd(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm256_ceil_pd(values); + } + Vectorized> floor() const { + return _mm256_floor_pd(values); + } + Vectorized> neg() const { + auto zero = _mm256_setzero_pd(); + return _mm256_sub_pd(zero, values); + } + Vectorized> round() const { + return _mm256_round_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. 
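+  // (Added note: operator!= below intentionally uses the unordered predicate
+  // _CMP_NEQ_UQ rather than an ordered one, so a NaN in either lane makes that
+  // lane compare as "not equal", mirroring scalar IEEE semantics.)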
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); + } + Vectorized> operator!=( + const Vectorized>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); + } + Vectorized> operator<( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_add_pd(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_sub_pd(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm256_mul_pd(a, b); // ac bd + + auto d_c = _mm256_permute_pd(b, 0x05); // d c + d_c = _mm256_xor_pd(sign_mask, d_c); // d -c + auto ad_bc = _mm256_mul_pd(a, d_c); // ad -bc + + auto ret = _mm256_hsub_pd(ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm256_set1_pd(-0.f); + // auto fabs_cd = _mm256_andnot_pd(mask, b); // |c| |d| + // auto fabs_dc = _mm256_permute_pd(fabs_cd, 0x05); // |d| |c| + // auto scale = _mm256_div_pd(_mm256_set1_pd(1.0f), _mm256_max_pd(fabs_cd, + // fabs_dc)); // 1/sc 1/sc auto a2 = _mm256_mul_pd(a, scale); // + // a/sc b/sc auto b2 = _mm256_mul_pd(b, scale); // c/sc d/sc + // auto acbd2 = _mm256_mul_pd(a2, b2); + + // const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0); + // auto dc2 = _mm256_permute_pd(b2, 0x05); // d/sc c/sc + // dc2 = _mm256_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm256_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = _mm256_hadd_pd(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm256_div_pd(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm256_loadu_pd(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. 
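+  // (Added note: until such handling exists, the per-lane loop below computes
+  // 1/z with c10::complex<double> arithmetic, preserving the scalar behaviour
+  // for 0/Inf/NaN components that the commented-out abs_2_() fast path would
+  // not reproduce.)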
+ // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + // auto c_d = _mm256_xor_pd(sign_mask, values); //c -d + // return _mm256_div_pd(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5); + + // auto sum = Vectorized(_mm256_add_pd(i, values)); // a + // 1+b auto sub = Vectorized(_mm256_sub_pd(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(max, isnan); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(min, isnan); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_and_pd(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_or_pd(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_xor_pd(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm256_set1_pd(1.0)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm256_set1_pd(1.0)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h new file mode 100644 index 0000000000000000000000000000000000000000..e87a012e33d5f3e9d7193706927514d1c57b62b8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h @@ -0,0 +1,618 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m256 values; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m256 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm256_setr_ps( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4) { + values = _mm256_setr_ps( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag()); + } + operator __m256() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 16, "Unexpected mask range"); + switch (mask) { + case 0: + return a; + case 1: + return _mm256_blend_ps( + a.values, b.values, 0x03); // b0000 0001 = b0000 0011 + case 2: + return _mm256_blend_ps( + a.values, b.values, 0x0C); // b0000 0010 = b0000 1100 + case 3: + return _mm256_blend_ps( + a.values, b.values, 0x0F); // b0000 0011 = b0000 1111 + case 4: + return _mm256_blend_ps( + a.values, b.values, 0x30); // b0000 0100 = b0011 0000 + case 5: + return _mm256_blend_ps( + a.values, b.values, 0x33); // b0000 0101 = b0011 0011 + case 6: + return _mm256_blend_ps( + a.values, b.values, 0x3C); // b0000 0110 = b0011 1100 + case 7: + return _mm256_blend_ps( + a.values, b.values, 0x3F); // b0000 0111 = b0011 1111 + case 8: + return _mm256_blend_ps( + a.values, b.values, 0xC0); // b0000 1000 = b1100 0000 + case 9: + return _mm256_blend_ps( + a.values, b.values, 0xC3); // b0000 1001 = b1100 0011 + case 10: + return _mm256_blend_ps( + a.values, b.values, 0xCC); // b0000 1010 = b1100 1100 + case 11: + return _mm256_blend_ps( + a.values, b.values, 0xCF); // b0000 1011 = b1100 1111 + case 12: + return _mm256_blend_ps( + a.values, b.values, 0xF0); // b0000 1100 = b1111 0000 + case 13: + return _mm256_blend_ps( + a.values, b.values, 0xF3); // b0000 1101 = b1111 0011 + case 14: + return _mm256_blend_ps( + a.values, b.values, 0xFC); // b0000 1110 = b1111 1100 + default: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values); + return _mm256_blendv_ps(a.values, b.values, mask_); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + step, + base + c10::complex(2) * step, + base + c10::complex(3) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static 
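+  // (Added note on blend/blendv above: each c10::complex<float> element spans
+  // two float lanes, so every bit of the 4-bit logical mask is duplicated into
+  // two adjacent bits of the _mm256_blend_ps immediate, e.g. 0b0101 becomes
+  // 0b00110011 = 0x33, and blendv widens its mask the same way via unpacklo.)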
Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm256_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm256_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2 * size()]; + _mm256_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256 abs_2_() const { + auto val_2 = _mm256_mul_ps(values, values); // a*a b*b + auto ret = _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return _mm256_permute_ps(ret, 0xD8); + } + __m256 abs_() const { + auto real = _mm256_moveldup_ps(values); // real real + auto imag = _mm256_movehdup_ps(values); // imag imag + return Sleef_hypotf8_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm256_and_ps(abs_(), real_mask); // abs 0 + } + __m256 angle_() const { + // angle = atan2(b/a) + auto b_a = _mm256_permute_ps(values, 0xB1); // b a + return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm256_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm256_setzero_ps(); + auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ); + auto div = _mm256_div_ps(values, abs); + return _mm256_blendv_ps(div, zero, mask); + } + __m256 real_() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm256_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m256 imag_() const { + const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF)); + return _mm256_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm256_permute_ps(imag_(), 0xB1); // b a + } + __m256 conj_() const { + const __m256 sign_mask = + _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm256_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); 
+ } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m256 log2_ = _mm256_set1_ps(std::log(2)); + return _mm256_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m256 log10_ = _mm256_set1_ps(std::log(10)); + return _mm256_div_ps(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m256 one = _mm256_set1_ps(1); + + // auto conj = conj_(); + // auto b_a = _mm256_permute_ps(conj, 0xB1); //-b a + // auto ab = _mm256_mul_ps(conj, b_a); //-ab + // -ab auto im = _mm256_add_ps(ab, ab); //-2ab -2ab + + // auto val_2 = _mm256_mul_ps(values, values); // a*a + // b*b auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1)); // + // a*a-b*b b*b-a*a re = _mm256_permute_ps(re, 0xD8); re = + // _mm256_sub_ps(one, re); + + // auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt(); //sqrt(re + + // i*im) auto ln = Vectorized(_mm256_add_ps(b_a, root)).log(); //ln(iz + + // sqrt()) return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj(); + // //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expf8_u10(values); //exp(a) exp(b) exp = + // _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosf8_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, + // 0xB1), + // sin_cos.x, 0xAA); //cos(b) sin(b) + // return _mm256_mul_ps(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m256 ln_2 = _mm256_set1_ps(c10::ln_2); + Vectorized> scaled_values = _mm256_mul_ps(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm256_ceil_ps(values); + } + Vectorized> floor() const { + return _mm256_floor_ps(values); + } + Vectorized> neg() const { + auto zero = _mm256_setzero_ps(); + return _mm256_sub_ps(zero, values); + } + Vectorized> round() const { + return _mm256_round_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); + } + Vectorized> operator!=( + const Vectorized>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); + } + Vectorized> operator<( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_add_ps(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_sub_ps(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256 sign_mask = + _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm256_mul_ps(a, b); // ac bd + + auto d_c = _mm256_permute_ps(b, 0xB1); // d c + d_c = _mm256_xor_ps(sign_mask, d_c); // d -c + auto ad_bc = _mm256_mul_ps(a, d_c); // ad -bc + + auto ret = _mm256_hsub_ps(ac_bd, ad_bc); // ac - bd ad + bc + ret = _mm256_permute_ps(ret, 0xD8); + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm256_set1_ps(-0.f); + // auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d| + // auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c| + // auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc + // auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc + // auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc + // auto acbd2 = _mm256_mul_ps(a2, b2); + + // const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc + // dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 + // res2 = _mm256_permute_ps(res2, 0xD8); + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm256_div_ps(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm256_loadu_ps(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. 
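+// (Added note on operator* and abs_2_() above: _mm256_hadd_ps/_mm256_hsub_ps
+// pair adjacent elements within each 128-bit half, leaving results ordered as
+// re0 re1 im0 im1 per half; the trailing _mm256_permute_ps(..., 0xD8) selects
+// lanes 0,2,1,3 to restore the re/im interleave.  The complex<double> variants
+// need no such fixup because each 128-bit half of a __m256d holds exactly one
+// complex number.)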
+inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); auto c_d = _mm256_xor_ps(sign_mask, values); //c -d + // return _mm256_div_ps(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm256_add_ps(i, values)); // a + // 1+b auto sub = Vectorized(_mm256_sub_ps(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. 
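+  // (Added note: maximum/minimum for complex vectors order lanes by squared
+  // magnitude via abs_2_(), avoiding a sqrt; the unordered compare below then
+  // forces lanes with a NaN magnitude to an all-ones, i.e. NaN, result.)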
+ auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_and_ps(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_or_ps(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_xor_ps(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm256_set1_ps(1.0f)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm256_set1_ps(1.0f)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..c5909452fca15afe47740db4b701d43ac3aa8a96 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h @@ -0,0 +1,365 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + __m256 value; + cvtbf16_fp32(_mm256_castsi256_si128(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + __m256 value; + cvtfp16_fp32(_mm256_castsi256_si128(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = _mm256_castsi128_si256(cvtfp32_bf16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = _mm256_castsi128_si256(cvtfp32_fp16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + +template <> +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src); + +template <> +struct 
VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low_double = at::vec::convert_to_fp_of_same_size(src[0]); + auto low = _mm256_cvtpd_ps(low_double); + auto high_double = at::vec::convert_to_fp_of_same_size(src[1]); + auto high = _mm256_cvtpd_ps(high_double); + return Vectorized( + _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + // Scalarization is the most reliable way of converting fp to int64 on AVX2. + // Check: https://stackoverflow.com/questions/41144668 + float buffer[8]; + src.store(buffer); + at::vec::VectorizedN result; + result[0] = Vectorized( + static_cast(buffer[0]), + static_cast(buffer[1]), + static_cast(buffer[2]), + static_cast(buffer[3])); + result[1] = Vectorized( + static_cast(buffer[4]), + static_cast(buffer[5]), + static_cast(buffer[6]), + static_cast(buffer[7])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0)); + auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0)); + auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0)); + auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0)); + return Vectorized(_mm256_blend_epi32(low_perm, high_perm, 0xF0)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src[0])); + result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src[0], 1)); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepi8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepu8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm256_cvttps_epi32(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm256_cvtepi32_ps(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepu8_epi16(src128)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + src_t, + 1, + typename std::enable_if_t< + (is_reduced_floating_point_v && is_8bit_integer_v) || + (is_reduced_floating_point_v && is_8bit_integer_v), + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 2, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2)); + __m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1); + // Shuffle [191:128] bit from combined in to [127:64] bit of result + __m256i result = + 
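+        // (Added note: the permute below reorders the 64-bit qwords of
+        // `combined` as 0,2,1,3, so that -- assuming convert_float_to_int8
+        // packs its 8 results into the low 64 bits of each vector, as the
+        // shuffle comment above implies -- the 16 valid int8 values end up
+        // contiguous in the low 128 bits of `result`.)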
_mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000); + return at::vec::Vectorized(result); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_float_to_int8(src[0]); + } +}; + +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + // Shuffle [127:64] bit from src[0] in to [191:128] bit of shuffled + __m256i shuffled = _mm256_permute4x64_epi64(src[0], 0b11011000); + __m256i src2 = + _mm256_castsi128_si256(_mm_castps_si128(_mm256_extractf128_ps( + _mm256_castsi256_ps(shuffled), 1) // Extract the second 128-bit lane + )); + return VectorizedN( + convert_int8_to_float(src[0]), + convert_int8_to_float(src2)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + int64_t, + 2, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const VectorizedN& src) { + return VecConvert::apply( + VecConvert::apply(src)); + } +}; + +#endif /* defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) */ + +#if (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)) +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_to_float(src[0]); + } +}; +#endif + +#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + // Load 16-bit unsigned integers from src into an SVE vector + svuint16_t u16x4 = + svld1_u16(svptrue_b16(), reinterpret_cast(&src[0])); + // Zero-extend to 32-bit SVE does not have direct vmovl_u16 equivalent. + vls_uint32_t u32x4 = + svreinterpret_u32_u16(svzip1_u16(svdup_n_u16(0), u16x4)); + // Reinterpret as float32 + vls_float32_t f32x4 = svreinterpret_f32_u32(u32x4); + res[0] = Vectorized(f32x4); + return res; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + std::tie(res[0], res[1]) = convert_bfloat16_float(src[0]); + return res; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + res[0] = convert_float_bfloat16(src[0], src[1]); + return res; + } +}; + +#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + auto [res_vec1, res_vec2] = convert_to_float(src[0]); + return res_vec1; + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_from_float(src[0], src[0]); + } +}; + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h new file mode 100644 index 0000000000000000000000000000000000000000..aaba475c17786966aef9b8c186cc1f8ad396fb7b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h @@ -0,0 +1,505 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + __m256d values; + + public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m256d v) : values(v) {} + Vectorized(double val) { + values = _mm256_set1_pd(val); + } + Vectorized(double val1, double val2, double val3, double val4) { + values = _mm256_setr_pd(val1, val2, val3, val4); + } + operator __m256d() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm256_blend_pd(a.values, b.values, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_pd(a.values, b.values, mask.values); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(double)); + return _mm256_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[size()]; + _mm256_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(double)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ); + return _mm256_movemask_pd(cmp); + } + Vectorized isnan() const { + return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q); + } + bool has_inf_nan() const { + __m256d self_sub = _mm256_sub_pd(values, values); + return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != + 0; + } + Vectorized map(double (*const f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm256_set1_pd(-0.f); + return _mm256_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm256_set1_pd(0.f); + const auto nan_vec = _mm256_set1_pd(NAN); + const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_pd(c10::pi); + + const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask); + angle = _mm256_blendv_pd(angle, nan_vec, nan_mask); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd4_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshd4_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind4_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhd4_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand4_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhd4_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2d4_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignd4(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd4_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd4_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expd4_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2d4_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d4_u10(values)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd4(values, q)); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotd4_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() 
const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd4_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d4_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d4_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd4_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind4_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhd4_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd4_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd4_u10(values)); + } + Vectorized ceil() const { + return _mm256_ceil_pd(values); + } + Vectorized floor() const { + return _mm256_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + return _mm256_xor_pd(_mm256_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterd4(values, b)); + } + Vectorized round() const { + return _mm256_round_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand4_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd4_u10(values)); + } + Vectorized trunc() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad4_u10(values)); + } + Vectorized sqrt() const { + return _mm256_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm256_div_pd(_mm256_set1_pd(1), values); + } + Vectorized rsqrt() const { + return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powd4_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); + } + + Vectorized operator!=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); + } + + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ); + } + + Vectorized operator<=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ); + } + + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ); + } + + Vectorized operator>=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_pd(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_pd(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mul_pd(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm256_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + Vectorized max = _mm256_max_pd(a, b); + Vectorized isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + Vectorized min = _mm256_min_pd(a, b); + Vectorized isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. 
+ return _mm256_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm256_min_pd(max, _mm256_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm256_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm256_min_pd(max, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_pd(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_pd(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_pd(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +#ifdef CPU_CAPABILITY_AVX2 +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmadd_pd(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmsub_pd(a, b, c); +} +#endif + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h new file mode 100644 index 0000000000000000000000000000000000000000..d80a558b69476565496659239138eb13460ac8a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h @@ -0,0 +1,768 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + __m256 values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m256 v) : values(v) {} + Vectorized(float val) { + values = _mm256_set1_ps(val); + } + Vectorized( + float val1, + float val2, + float val3, + float val4, + float val5, + float val6, + float val7, + float val8) { + values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8); + } + Vectorized(const float (&arr)[8]) + : Vectorized( + arr[0], + arr[1], + arr[2], + arr[3], + arr[4], + arr[5], + arr[6], + arr[7]) {} + operator __m256() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm256_blend_ps(a.values, b.values, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_ps(a.values, b.values, mask.values); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_ps(reinterpret_cast(ptr)); + __at_align__ float tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
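// Illustrative aside (standalone sketch, not part of the vendored header):
// the partial-count load pattern described in the comment above. Elements past
// `count` are explicitly zeroed before the full-width load, so the result
// never depends on uninitialized stack memory (see pytorch/pytorch#32502).
// The real header zeroes with a loop rather than "= {}" for codegen reasons;
// the function name here is illustrative.
#include <immintrin.h>
#include <cstring>

inline __m256 loadu_partial_floats(const float* src, int count /* 0..8 */) {
  alignas(32) float tmp[8] = {};                 // zero-fill the tail lanes
  std::memcpy(tmp, src, count * sizeof(float));  // copy only the valid elements
  return _mm256_loadu_ps(tmp);                   // full 8-lane load, tail lanes are 0
}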
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, reinterpret_cast(ptr), count * sizeof(float)); + return _mm256_loadu_ps(tmp_values); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm256_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[size()]; + _mm256_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + return _mm256_movemask_ps(cmp); + } + Vectorized isnan() const { + return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + } + + bool has_inf_nan() const { + __m256 self_sub = _mm256_sub_ps(values, values); + return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != + 0; + } + + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm256_set1_ps(-0.f); + return _mm256_andnot_ps(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(c10::pi); + + const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf8_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshf8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf8_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhf8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf8_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhf8_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2f8_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignf8(values, sign)); + } + Vectorized erf() const { + // constants + const auto neg_zero_vec = _mm256_set1_ps(-0.f); + const auto one_vec = _mm256_set1_ps(1.0f); + const auto p = _mm256_set1_ps(0.3275911f); + const auto p1 = _mm256_set1_ps(0.254829592f); + const auto p2 = _mm256_set1_ps(-0.284496736f); + const auto p3 = _mm256_set1_ps(1.421413741f); + const auto p4 = _mm256_set1_ps(-1.453152027f); + const auto p5 = _mm256_set1_ps(1.061405429f); + // sign(x) + auto sign_mask = _mm256_and_ps(neg_zero_vec, values); + auto abs_vec = _mm256_xor_ps(sign_mask, values); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = _mm256_fmadd_ps(p, abs_vec, one_vec); + auto t = _mm256_div_ps(one_vec, tmp0); + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = _mm256_fmadd_ps(p5, t, p4); + auto tmp2 
= _mm256_fmadd_ps(tmp1, t, p3); + auto tmp3 = _mm256_fmadd_ps(tmp2, t, p2); + auto r = _mm256_fmadd_ps(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = _mm256_mul_ps(values, values); + auto neg_pow_2 = _mm256_xor_ps(neg_zero_vec, pow_2); + // auto tmp4 = exp(neg_pow_2); + auto tmp4 = Vectorized(Sleef_expf8_u10(neg_pow_2)); + auto tmp5 = _mm256_xor_ps(neg_zero_vec, tmp4); + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = _mm256_mul_ps(tmp5, t); + auto tmp7 = _mm256_fmadd_ps(tmp6, r, one_vec); + return _mm256_xor_ps(sign_mask, tmp7); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf8_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2f8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f8_u10(values)); + } + Vectorized exp_u20() const { + // A faster version of exp with ULP=20 + const __m256 vec_factorial_1 = + _mm256_set1_ps(0.999999701f); // 1/factorial(1) + const __m256 vec_factorial_2 = + _mm256_set1_ps(0.499991506f); // 1/factorial(2) + const __m256 vec_factorial_3 = + _mm256_set1_ps(0.166676521f); // 1/factorial(3) + const __m256 vec_factorial_4 = + _mm256_set1_ps(0.0418978221f); // 1/factorial(4) + const __m256 vec_factorial_5 = + _mm256_set1_ps(0.00828929059f); // 1/factorial(5) + const __m256 vec_exp_log2ef = + _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e) + const __m256 vec_half = _mm256_set1_ps(0.5f); + const __m256 vec_one = _mm256_set1_ps(1.f); + const __m256 vec_zero = _mm256_set1_ps(0.f); + const __m256 vec_two = _mm256_set1_ps(2.f); + const __m256 vec_ln2f = + _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2) + const __m256 vec_ln_flt_min = + _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); + const __m256 vec_ln_flt_max = + _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); + const __m256i vec_127 = _mm256_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; + + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + auto less_ln_flt_min_mask = + _mm256_cmp_ps(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); + auto vec_src = _mm256_min_ps(values, vec_ln_flt_max); + vec_src = _mm256_max_ps(vec_src, vec_ln_flt_min); + + // fx = floorf(x * log2ef + 0.5) + auto vec_fx = _mm256_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); + vec_fx = _mm256_floor_ps(vec_fx); + + // x = x - fx * ln2 + auto vec_exp_poly = _mm256_fnmadd_ps(vec_fx, vec_ln2f, vec_src); + + // compute polynomial + auto vec_res = + _mm256_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_one); + + // compute 2^(n-1) + auto vec_exp_number = _mm256_sub_ps(vec_fx, vec_one); + auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number); + auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127); + vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); + auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i); + vec_two_pow_n = + _mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask); + + // y = y * 2^n + vec_res = _mm256_mul_ps(vec_res, vec_two_pow_n); + vec_res = _mm256_mul_ps(vec_res, 
vec_two); + return vec_res; + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf8(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f8_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pf8_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf8_u35(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhf8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosf8_u35(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf8_u10(values)); + } + Vectorized ceil() const { + return _mm256_ceil_ps(values); + } + Vectorized floor() const { + return _mm256_floor_ps(values); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotf8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm256_xor_ps(_mm256_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterf8(values, b)); + } + Vectorized round() const { + return _mm256_round_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf8_u10(values)); + } + Vectorized trunc() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammaf8_u10(values)); + } + Vectorized sqrt() const { + return _mm256_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm256_div_ps(_mm256_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powf8_u10(values, b)); + } + float reduce_add() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = _mm256_add_ps(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = _mm256_add_ps(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = _mm256_add_ps(v, v1); + return _mm256_cvtss_f32(v); + } + float reduce_max() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = _mm256_max_ps(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = _mm256_max_ps(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = _mm256_max_ps(v, v1); + return _mm256_cvtss_f32(v); + } + // Comparison using the _CMP_**_OQ predicate. 
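// Illustrative aside (standalone sketch, not part of the vendored header):
// the log2(width) shuffle-and-combine reduction used by reduce_add() and
// reduce_max() above. Each step folds the upper half of the active region onto
// the lower half: 128-bit halves, then 64-bit pairs, then 32-bit neighbours.
// Function name is illustrative.
#include <immintrin.h>
#include <cstdio>

inline float hsum_avx(__m256 v) {
  __m256 t = _mm256_permute2f128_ps(v, v, 0x1);  // swap 128-bit halves
  v = _mm256_add_ps(v, t);
  t = _mm256_shuffle_ps(v, v, 0x4E);             // swap 64-bit pairs
  v = _mm256_add_ps(v, t);
  t = _mm256_shuffle_ps(v, v, 0xB1);             // swap 32-bit neighbours
  v = _mm256_add_ps(v, t);
  return _mm256_cvtss_f32(v);                    // every lane now holds the sum
}

int main() {
  __m256 v = _mm256_setr_ps(1, 2, 3, 4, 5, 6, 7, 8);
  std::printf("%f\n", hsum_avx(v));              // 36.0
}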
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); + } + + Vectorized operator!=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); + } + + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ); + } + + Vectorized operator<=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ); + } + + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ); + } + + Vectorized operator>=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_ps(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_ps(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mul_ps(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm256_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + Vectorized max = _mm256_max_ps(a, b); + Vectorized isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + Vectorized min = _mm256_min_ps(a, b); + Vectorized isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. 
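// Illustrative aside (scalar sketch, not part of the vendored header):
// frac() above is defined as x - trunc(x), which keeps the sign of the input,
// unlike x - floor(x), which is always non-negative.
#include <cmath>
#include <cstdio>

int main() {
  double x = -2.75;
  std::printf("%f %f\n", x - std::trunc(x), x - std::floor(x));  // -0.75 vs 0.25
}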
+ return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm256_min_ps(max, _mm256_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm256_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm256_max_ps(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_ps(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_ps(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_ps(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmadd_ps(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmsub_ps(a, b, c); +} + +// TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen for micro gemm +inline void transpose_block(at::vec::VectorizedN& input) { + __m256 temp0[8]; + // unpacking and interleaving 32-bit elements + // a0 b0 a1 b1 a4 b4 a5 b5 + // a2 b2 a3 b3 a6 b6 a7 b7 + // c0 d0 c1 d1 ... + // c2 d2 c3 d3 ... + // e0 f0 e1 f1 ... + // e2 f2 e3 f3 ... + // g0 h0 g1 h1 ... + // g2 h2 g3 h3 ... + temp0[0] = _mm256_unpacklo_ps(input[0], input[1]); + temp0[1] = _mm256_unpackhi_ps(input[0], input[1]); + temp0[2] = _mm256_unpacklo_ps(input[2], input[3]); + temp0[3] = _mm256_unpackhi_ps(input[2], input[3]); + temp0[4] = _mm256_unpacklo_ps(input[4], input[5]); + temp0[5] = _mm256_unpackhi_ps(input[4], input[5]); + temp0[6] = _mm256_unpacklo_ps(input[6], input[7]); + temp0[7] = _mm256_unpackhi_ps(input[6], input[7]); + + __m256 temp1[8]; + // unpacking and interleaving 64-bit elements + // a0 b0 c0 d0 a4 b4 c4 d4 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // e0 f0 g0 h0 e4 f4 g4 h4 + // e1 f1 g1 h1 ... + // e2 f2 g2 h2 ... + // e3 f3 g3 h3 ... 
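// Illustrative aside (standalone sketch, not part of the vendored header):
// why the interleave pattern in the comment above reads a0 b0 a1 b1 a4 b4 a5
// b5. _mm256_unpacklo_ps operates on each 128-bit lane independently, so the
// "low" unpack takes elements 0 and 1 of each lane rather than 0..3 of the
// whole register.
#include <immintrin.h>
#include <cstdio>

int main() {
  __m256 a = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);          // a0..a7
  __m256 b = _mm256_setr_ps(10, 11, 12, 13, 14, 15, 16, 17);  // b0..b7
  float lo[8];
  _mm256_storeu_ps(lo, _mm256_unpacklo_ps(a, b));
  for (float v : lo) std::printf("%g ", v);  // 0 10 1 11 4 14 5 15
  std::printf("\n");
}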
+ temp1[0] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[0]), _mm256_castps_pd(temp0[2]))); + temp1[1] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[0]), _mm256_castps_pd(temp0[2]))); + temp1[2] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[1]), _mm256_castps_pd(temp0[3]))); + temp1[3] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[1]), _mm256_castps_pd(temp0[3]))); + temp1[4] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[4]), _mm256_castps_pd(temp0[6]))); + temp1[5] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[4]), _mm256_castps_pd(temp0[6]))); + temp1[6] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[5]), _mm256_castps_pd(temp0[7]))); + temp1[7] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[5]), _mm256_castps_pd(temp0[7]))); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 e0 f0 g0 h0 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // a4 b4 c4 d4 ... + // a5 b5 c5 d5 ... + // a6 b6 c6 d6 ... + // a7 b7 c7 d7 ... + input[0] = _mm256_permute2f128_ps(temp1[0], temp1[4], 0x20); + input[1] = _mm256_permute2f128_ps(temp1[1], temp1[5], 0x20); + input[2] = _mm256_permute2f128_ps(temp1[2], temp1[6], 0x20); + input[3] = _mm256_permute2f128_ps(temp1[3], temp1[7], 0x20); + input[4] = _mm256_permute2f128_ps(temp1[0], temp1[4], 0x31); + input[5] = _mm256_permute2f128_ps(temp1[1], temp1[5], 0x31); + input[6] = _mm256_permute2f128_ps(temp1[2], temp1[6], 0x31); + input[7] = _mm256_permute2f128_ps(temp1[3], temp1[7], 0x31); +} + +// Used by Inductor CPP codegen +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + // load from src to registers + at::vec::VectorizedN input; + // a: a0 a1 a2 a3 a4 a5 a6 a7 + // b: b0 b1 b2 b3 b4 b5 b6 b7 + // c: c0 c1 c2 c3 c4 c5 c6 c7 + // d: d0 d1 d2 d3 d4 d5 d6 d7 + // e: e0 e1 e2 e3 e4 e5 e6 e7 + // f: f0 f1 f2 f3 f4 f5 f6 f7 + // g: g0 g1 g2 g3 g4 g5 g6 g7 + // h: h0 h1 h2 h3 h4 h5 h6 h7 + int i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i < 8; i++) { + input[i] = _mm256_loadu_ps(&src[i * ld_src]); + } + + transpose_block(input); + + // store from registers to dst +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i < 8; i++) { + _mm256_storeu_ps(&dst[i * ld_dst], input[i]); + } +} + +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst); + transpose_mxn(src + 8, ld_src, dst + 8 * ld_dst, ld_dst); + transpose_mxn(src + 8 * ld_src, ld_src, dst + 8, ld_dst); + transpose_mxn( + src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst); +} +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h new file mode 100644 index 0000000000000000000000000000000000000000..cda6fd94d5702fbdd2c5c30793e71bc4bbb58bdf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h @@ -0,0 +1,280 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
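// Illustrative aside (scalar sketch, not part of the vendored header): the
// 16x16 transpose at the end of vec256_float.h is built from four 8x8
// transposes; destination block (i,j) is the transpose of source block (j,i).
// The helpers below are hypothetical names showing the same decomposition.
#include <cstddef>

template <std::size_t B>
void transpose_block_scalar(const float* src, std::size_t ld_src,
                            float* dst, std::size_t ld_dst) {
  for (std::size_t r = 0; r < B; ++r)
    for (std::size_t c = 0; c < B; ++c)
      dst[c * ld_dst + r] = src[r * ld_src + c];
}

inline void transpose_16x16_scalar(const float* src, std::size_t ld_src,
                                   float* dst, std::size_t ld_dst) {
  transpose_block_scalar<8>(src,                  ld_src, dst,                  ld_dst);
  transpose_block_scalar<8>(src + 8,              ld_src, dst + 8 * ld_dst,     ld_dst);
  transpose_block_scalar<8>(src + 8 * ld_src,     ld_src, dst + 8,              ld_dst);
  transpose_block_scalar<8>(src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst);
}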
+// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX2 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = Half; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_si256(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_si256(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_si256(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + auto max_lo = _mm256_max_ps(a_lo, b_lo); + auto max_hi = _mm256_max_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(max_lo, nan_lo); + auto o2 = _mm256_or_ps(max_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
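// Illustrative aside (standalone sketch, not part of the vendored header):
// the Half (fp16) specializations in this file do their arithmetic by widening
// each 16-lane half-precision vector into two 8-lane fp32 vectors, computing
// in fp32, and narrowing back. A minimal version of that widen/compute/narrow
// round trip using the F16C intrinsics; the function name is illustrative and
// this is only assumed to mirror what the header's conversion helpers do.
#include <immintrin.h>

inline __m256i add_halfs(__m256i a, __m256i b) {
  // widen: the low and high 128 bits each hold 8 fp16 values
  __m256 a_lo = _mm256_cvtph_ps(_mm256_castsi256_si128(a));
  __m256 a_hi = _mm256_cvtph_ps(_mm256_extracti128_si256(a, 1));
  __m256 b_lo = _mm256_cvtph_ps(_mm256_castsi256_si128(b));
  __m256 b_hi = _mm256_cvtph_ps(_mm256_extracti128_si256(b, 1));
  // compute in fp32
  __m256 o_lo = _mm256_add_ps(a_lo, b_lo);
  __m256 o_hi = _mm256_add_ps(a_hi, b_hi);
  // narrow back to fp16 with round-to-nearest
  __m128i lo = _mm256_cvtps_ph(o_lo, _MM_FROUND_TO_NEAREST_INT);
  __m128i hi = _mm256_cvtps_ph(o_hi, _MM_FROUND_TO_NEAREST_INT);
  return _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1);
}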
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + auto min_lo = _mm256_min_ps(a_lo, b_lo); + auto min_hi = _mm256_min_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(min_lo, nan_lo); + auto o2 = _mm256_or_ps(min_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + __m256 max_lo, max_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(min), min_lo, min_hi); + cvtfp16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo)); + auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi)); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 max_lo, max_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, a_lo); + auto o2 = _mm256_min_ps(max_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(min), min_lo, min_hi); + auto o1 = _mm256_max_ps(min_lo, a_lo); + auto o2 = _mm256_max_ps(min_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +inline void convert(const Half* src, Half* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, Half* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = _mm256_loadu_ps(&src[i]); + __m256 b = _mm256_loadu_ps(&src[i + 8]); + + __m256i c = cvtfp32_fp16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, Half* dst, int64_t n) { + auto load_float = [](const double* src) -> __m256 { + // Load one float vector from an array of doubles + __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src)); + __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4)); + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = load_float(&src[i]); + __m256 b = load_float(&src[i + 8]); + + __m256i c = cvtfp32_fp16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + __m256 c_lo, c_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, 
b_hi); + cvtfp16_fp32(__m256i(c), c_lo, c_hi); + auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_fp16(o1, o2); +} + +CONVERT_VECTORIZED_INIT(Half, half) +LOAD_FP32_VECTORIZED_INIT(Half, fp16) + +#else // defined(CPU_CAPABILITY_AVX2) + +#if !( \ + defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE256)) +CONVERT_NON_VECTORIZED_INIT(Half, half) +#endif + +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16) +#endif // defined(CPU_CAPABILITY_AVX2) +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h new file mode 100644 index 0000000000000000000000000000000000000000..6df2d46062d8beaeb9158d13e4ca517dc03a2b74 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h @@ -0,0 +1,2318 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX2 + +struct Vectorizedi { + protected: + __m256i values; + + static inline __m256i invert(const __m256i& v) { + const auto ones = _mm256_set1_epi64x(-1); + return _mm256_xor_si256(ones, v); + } + + public: + Vectorizedi() {} + Vectorizedi(__m256i v) : values(v) {} + operator __m256i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX2 + +#ifdef CPU_CAPABILITY_AVX2 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 4; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int64_t v) { + values = _mm256_set1_epi64x(v); + } + Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) { + values = _mm256_setr_epi64x(val1, val2, val3, val4); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + __at_align__ int64_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi64(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi64(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi64(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi64(b.values, 3); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int64_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ int64_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output 
value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int64_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int64_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto zero = _mm256_set1_epi64x(0); + auto is_larger = _mm256_cmpgt_epi64(zero, values); + auto inverse = _mm256_xor_si256(values, is_larger); + return _mm256_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi64x(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi64(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi64(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi64(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi64(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi64(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi64(other.values, values)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int32_t; + static constexpr int size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { + values = _mm256_set1_epi32(v); + } + Vectorized( + int32_t val1, + int32_t val2, + int32_t val3, + int32_t val4, + int32_t val5, + int32_t val6, + int32_t val7, + int32_t val8) { + values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm256_blend_epi32(a, b, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int32_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + 
base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + __at_align__ int32_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int32_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int32_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm256_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + int32_t reduce_add() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_si256(v, v, 0x1); + v = _mm256_add_epi32(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0x4E); + v = _mm256_add_epi32(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0xB1); + v = _mm256_add_epi32(v, v1); + __m128i lo = _mm256_castsi256_si128(v); + return _mm_cvtsi128_si32(lo); + } + int32_t reduce_max() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_si256(v, v, 0x1); + v = _mm256_max_epi32(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0x4E); + v = _mm256_max_epi32(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0xB1); + v = _mm256_max_epi32(v, v1); + __m128i lo = _mm256_castsi256_si128(v); + return _mm_cvtsi128_si32(lo); + } + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi32(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi32(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi32(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi32(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return 
_mm256_cmpgt_epi32(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi32(other.values, values)); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_vec = + _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm256_cvtepi32_ps(input_vec); + _mm256_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, double* dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_128_vec = + _mm_loadu_si128(reinterpret_cast(src + i)); + auto output_vec = _mm256_cvtepi32_pd(input_128_vec); + _mm256_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int16_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { + values = _mm256_set1_epi16(v); + } + Vectorized( + int16_t val1, + int16_t val2, + int16_t val3, + int16_t val4, + int16_t val5, + int16_t val6, + int16_t val7, + int16_t val8, + int16_t val9, + int16_t val10, + int16_t val11, + int16_t val12, + int16_t val13, + int16_t val14, + int16_t val15, + int16_t val16) { + values = _mm256_setr_epi16( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi16(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi16(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi16(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi16(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi16(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi16(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi16(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi16(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi16(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi16(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi16(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi16(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi16(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi16(b.values, 13); + if (mask & 0x4000) + 
tmp_values[14] = _mm256_extract_epi16(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi16(b.values, 15); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + __at_align__ int16_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm256_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi16(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi16(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi16(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi16(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi16(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi16(other.values, values)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + + protected: + static const Vectorized ones; + + public: + using value_type = T; + static constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized8() {} + Vectorized8(T v) { + values = _mm256_set1_epi8(v); + } + Vectorized8( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32) { + values = _mm256_setr_epi8( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16, + val17, + val18, + val19, + val20, + val21, + val22, + val23, + val24, + val25, + val26, + val27, + val28, + val29, + val30, + val31, + val32); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + __at_align__ T tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi8(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi8(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi8(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi8(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi8(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi8(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi8(b.values, 
6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi8(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi8(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi8(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi8(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi8(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi8(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi8(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi8(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi8(b.values, 15); + if (mask & 0x010000) + tmp_values[16] = _mm256_extract_epi8(b.values, 16); + if (mask & 0x020000) + tmp_values[17] = _mm256_extract_epi8(b.values, 17); + if (mask & 0x040000) + tmp_values[18] = _mm256_extract_epi8(b.values, 18); + if (mask & 0x080000) + tmp_values[19] = _mm256_extract_epi8(b.values, 19); + if (mask & 0x100000) + tmp_values[20] = _mm256_extract_epi8(b.values, 20); + if (mask & 0x200000) + tmp_values[21] = _mm256_extract_epi8(b.values, 21); + if (mask & 0x400000) + tmp_values[22] = _mm256_extract_epi8(b.values, 22); + if (mask & 0x800000) + tmp_values[23] = _mm256_extract_epi8(b.values, 23); + if (mask & 0x1000000) + tmp_values[24] = _mm256_extract_epi8(b.values, 24); + if (mask & 0x2000000) + tmp_values[25] = _mm256_extract_epi8(b.values, 25); + if (mask & 0x4000000) + tmp_values[26] = _mm256_extract_epi8(b.values, 26); + if (mask & 0x8000000) + tmp_values[27] = _mm256_extract_epi8(b.values, 27); + if (mask & 0x10000000) + tmp_values[28] = _mm256_extract_epi8(b.values, 28); + if (mask & 0x20000000) + tmp_values[29] = _mm256_extract_epi8(b.values, 29); + if (mask & 0x40000000) + tmp_values[30] = _mm256_extract_epi8(b.values, 30); + if (mask & 0x80000000) + tmp_values[31] = _mm256_extract_epi8(b.values, 31); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set(Vectorized a, Vectorized b, T count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return 
blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu_one_fourth(const void* ptr) { + // Fast path if only load element number of 8. + // Note: We didn't merge it as fast path of loadu(const void* ptr, T count), + // Because loadu(const void* ptr, T count) requires zero initialization for + // upper 128 bits. However, by using _mm256_castsi128_si256, the upper 128 + // bits of the result are undefined. + // TODO We can use _mm256_zextsi128_si256 in the furture, + // since gcc 9.3 doesn't support it now. + __m128i input_128 = _mm_loadl_epi64(reinterpret_cast(ptr)); + return _mm256_castsi128_si256(input_128); + } + static Vectorized loadu(const void* ptr, T count) { + __at_align__ T tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(T)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + if (count == 8) { + // Fast path if only store element number of 8 + _mm_storel_epi64( + reinterpret_cast<__m128i*>(ptr), _mm256_castsi256_si128(values)); + } else { + __at_align__ T tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(T)); + } + } + } + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + Vectorized neg() const; + + Vectorized abs() const { + return _mm256_abs_epi8(values); + } + + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi8(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi8(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi8(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi8(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi8(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi8(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return invert(_mm256_cmpeq_epi8(max, values)); + } + Vectorized operator<=(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return _mm256_cmpeq_epi8(max, other.values); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi64(a, b); +} + +template <> 
+Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi16(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi8(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi8(a, b); +} + +// Negation. Defined here so we can utilize operator- +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +// Emulate operations with no native 64-bit support in avx, +// by extracting each element, performing the operation pointwise, +// then combining the results into a vector. +template +Vectorized inline emulate( + const Vectorized& a, + const Vectorized& b, + const op_t& op) { + int64_t a0 = _mm256_extract_epi64(a, 0); + int64_t a1 = _mm256_extract_epi64(a, 1); + int64_t a2 = _mm256_extract_epi64(a, 2); + int64_t a3 = _mm256_extract_epi64(a, 3); + + int64_t b0 = _mm256_extract_epi64(b, 0); + int64_t b1 = _mm256_extract_epi64(b, 1); + int64_t b2 = _mm256_extract_epi64(b, 2); + int64_t b3 = _mm256_extract_epi64(b, 3); + + int64_t c0 = op(a0, b0); + int64_t c1 = op(a1, b1); + int64_t c2 = op(a2, b2); + int64_t c3 = op(a3, b3); + + return _mm256_set_epi64x(c3, c2, c1, c0); +} + +template +Vectorized inline emulate( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c, + const op_t& op) { + int64_t a0 = _mm256_extract_epi64(a, 0); + int64_t a1 = _mm256_extract_epi64(a, 1); + int64_t a2 = _mm256_extract_epi64(a, 2); + int64_t a3 = _mm256_extract_epi64(a, 3); + + int64_t b0 = _mm256_extract_epi64(b, 0); + int64_t b1 = _mm256_extract_epi64(b, 1); + int64_t b2 = _mm256_extract_epi64(b, 2); + int64_t b3 = _mm256_extract_epi64(b, 3); + + int64_t c0 = _mm256_extract_epi64(c, 0); + int64_t c1 = _mm256_extract_epi64(c, 1); + int64_t c2 = _mm256_extract_epi64(c, 2); + int64_t c3 = _mm256_extract_epi64(c, 3); + + int64_t d0 = op(a0, b0, c0); + int64_t d1 = op(a1, b1, c1); + int64_t d2 = op(a2, b2, c2); + int64_t d3 = op(a3, b3, c3); + + return _mm256_set_epi64x(d3, d2, d1, d0); +} + +// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated +// This could be implemented more efficiently using epi32 instructions +// This is also technically avx compatible, but then we'll need AVX +// code for add as well. +// Note: intentionally ignores undefined behavior like (-lowest * -1). 
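+// Illustrative scalar reference for the emulated 64-bit multiply below
+// (a sketch added for exposition, not part of the original header): each
+// lane is multiplied independently with wrap-around semantics, which is
+// the same lane-by-lane shape the `emulate` helper above produces.
+inline void mul_i64_scalar_ref(const int64_t* a, const int64_t* b, int64_t* out) {
+  for (int i = 0; i < 4; ++i) {
+    // Multiply in unsigned arithmetic so the wrap-around is well defined.
+    out[i] = static_cast<int64_t>(
+        static_cast<uint64_t>(a[i]) * static_cast<uint64_t>(b[i]));
+  }
+}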
+template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return emulate( + a, b, [](int64_t a_point, int64_t b_point) __ubsan_ignore_undefined__ { + return a_point * b_point; + }); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_256( + const Vectorized& a, + const Vectorized& b, + Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying int8_t +#ifndef CPU_CAPABILITY_AVX2 + return int_elementwise_binary_256(a, b, std::multiplies()); +#else + __m256i mask00FF = _mm256_set1_epi16(0x00FF); + __m256i a_lo = _mm256_srai_epi16(_mm256_slli_epi16(a, 8), 8); + __m256i b_lo = _mm256_srai_epi16(_mm256_slli_epi16(b, 8), 8); + __m256i a_hi = _mm256_srai_epi16(a, 8); + __m256i b_hi = _mm256_srai_epi16(b, 8); + __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF); + __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8); + __m256i res = _mm256_or_si256(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t +#ifndef CPU_CAPABILITY_AVX2 + return int_elementwise_binary_256(a, b, std::multiplies()); +#else + __m256i mask00FF = _mm256_set1_epi16(0x00FF); + __m256i a_lo = _mm256_and_si256(a, mask00FF); + __m256i b_lo = _mm256_and_si256(b, mask00FF); + __m256i a_hi = _mm256_srli_epi16(a, 8); + __m256i b_hi = _mm256_srli_epi16(b, 8); + __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF); + __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8); + __m256i res = _mm256_or_si256(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, b, [](int64_t a_point, int64_t b_point) { + return std::min(a_point, b_point); + }); +#else + __m256i cmp = _mm256_cmpgt_epi64(a, b); + return _mm256_blendv_epi8(a, b, cmp); +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi32(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi16(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi8(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epu8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, b, [](int64_t a_point, int64_t b_point) { + return std::max(a_point, b_point); + }); +#else + __m256i cmp = _mm256_cmpgt_epi64(a, b); + return _mm256_blendv_epi8(b, a, cmp); +#endif +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return 
_mm256_max_epi32(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi16(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epu8(a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate( + a, + min_val, + max_val, + [](int64_t a_point, int64_t min_point, int64_t max_point) { + return std::min(max_point, std::max(a_point, min_point)); + }); +#else + return minimum(maximum(a, min_val), max_val); +#endif +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, max_val, [](int64_t a_point, int64_t max_point) { + return std::min(max_point, a_point); + }); +#else + return minimum(max_val, a); +#endif +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi8(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epu8(max_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, min_val, [](int64_t a_point, int64_t min_point) { + return std::max(min_point, a_point); + }); +#else + return maximum(min_val, a); +#endif +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi8(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epu8(min_val, a); +} + +template +std::enable_if_t< + !(std::is_same_v || std::is_same_v), + Vectorized< + int32_t>> inline convert_to_int32(const T* ptr, int count = Vectorized::size()) { + return Vectorized::loadu(ptr, count); +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( 
+ const int8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm256_cvtepi8_epi32( + _mm_loadl_epi64(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm256_cvtepi8_epi32(_mm256_castsi256_si128(a)); + } +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const uint8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm256_cvtepu8_epi32(_mm256_castsi256_si128(a)); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} + +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm256_and_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm256_or_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm256_xor_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm256_xor_si256(a, _mm256_set1_epi32(-1)); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const 
Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template +Vectorized inline shift_256_16( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int16_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 16-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m256i ctl_0_1 = _mm256_set_epi8( + 29, + 28, + 0x80, + 0x80, + 25, + 24, + 0x80, + 0x80, + 21, + 20, + 0x80, + 0x80, + 17, + 16, + 0x80, + 0x80, + 13, + 12, + 0x80, + 0x80, + 9, + 8, + 0x80, + 0x80, + 5, + 4, + 0x80, + 0x80, + 1, + 0, + 0x80, + 0x80); + __m256i ctl_1_0 = _mm256_set_epi8( + 0x80, + 0x80, + 31, + 30, + 0x80, + 0x80, + 27, + 26, + 0x80, + 0x80, + 23, + 22, + 0x80, + 0x80, + 19, + 18, + 0x80, + 0x80, + 15, + 14, + 0x80, + 0x80, + 11, + 10, + 0x80, + 0x80, + 7, + 6, + 0x80, + 0x80, + 3, + 2); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 16-bit elements, and considering them in pairs of neighboring + // elements. 
A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFFFF); + __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000); + + // Take each 16-bit element with idx%2==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 16 + // bits will be proper result of shifting original 16-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else + c0 = _mm256_srav_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m256i a1 = _mm256_and_si256(a, keep_1); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else + c1 = _mm256_srav_epi32(a1, b1); + c1 = _mm256_and_si256(c1, keep_1); + + // Merge partial results into the final result. + __m256i c = _mm256_or_si256(c0, c1); + + return c; +} + +template < + bool left_shift, + typename T, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + int> = 0> +Vectorized inline shift_256_8( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 8-bit elements, and considering quadruples of + // neighboring elements. Specifially, a mask named "ctl_M_N" (M,N + // in [0,1,2,3], and M!=N) is set so that shuffle will move element + // with index M from input quadruple into element with index N in + // output quadruple, and other elements in output quadruple will be + // set to all 0s. 
+ __m256i ctl_0_3 = _mm256_set_epi8( + 28, + 0x80, + 0x80, + 0x80, + 24, + 0x80, + 0x80, + 0x80, + 20, + 0x80, + 0x80, + 0x80, + 16, + 0x80, + 0x80, + 0x80, + 12, + 0x80, + 0x80, + 0x80, + 8, + 0x80, + 0x80, + 0x80, + 4, + 0x80, + 0x80, + 0x80, + 0, + 0x80, + 0x80, + 0x80); + __m256i ctl_1_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 29, + 0x80, + 0x80, + 0x80, + 25, + 0x80, + 0x80, + 0x80, + 21, + 0x80, + 0x80, + 0x80, + 17, + 0x80, + 0x80, + 0x80, + 13, + 0x80, + 0x80, + 0x80, + 9, + 0x80, + 0x80, + 0x80, + 5, + 0x80, + 0x80, + 0x80, + 1); + __m256i ctl_1_3 = _mm256_set_epi8( + 29, + 0x80, + 0x80, + 0x80, + 25, + 0x80, + 0x80, + 0x80, + 21, + 0x80, + 0x80, + 0x80, + 17, + 0x80, + 0x80, + 0x80, + 13, + 0x80, + 0x80, + 0x80, + 9, + 0x80, + 0x80, + 0x80, + 5, + 0x80, + 0x80, + 0x80, + 1, + 0x80, + 0x80, + 0x80); + __m256i ctl_2_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 30, + 0x80, + 0x80, + 0x80, + 26, + 0x80, + 0x80, + 0x80, + 22, + 0x80, + 0x80, + 0x80, + 18, + 0x80, + 0x80, + 0x80, + 14, + 0x80, + 0x80, + 0x80, + 10, + 0x80, + 0x80, + 0x80, + 6, + 0x80, + 0x80, + 0x80, + 2); + __m256i ctl_2_3 = _mm256_set_epi8( + 30, + 0x80, + 0x80, + 0x80, + 26, + 0x80, + 0x80, + 0x80, + 22, + 0x80, + 0x80, + 0x80, + 18, + 0x80, + 0x80, + 0x80, + 14, + 0x80, + 0x80, + 0x80, + 10, + 0x80, + 0x80, + 0x80, + 6, + 0x80, + 0x80, + 0x80, + 2, + 0x80, + 0x80, + 0x80); + __m256i ctl_3_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3); + __m256i ctl_3_1 = _mm256_set_epi8( + 0x80, + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3, + 0x80); + __m256i ctl_3_2 = _mm256_set_epi8( + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3, + 0x80, + 0x80); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 8-bit elements, and considering them in quadruples of neighboring + // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that + // bitwise and will copy element with index M from input quadruple + // into element with the same index in output quadruple, while the + // other elements in output quadruple will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFF); + __m256i keep_3 = _mm256_set1_epi32(0xFF000000); + + // Take each 8-bit element with idx%4==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%4!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. 
However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else if constexpr (std::is_same_v) + c0 = _mm256_srav_epi32(a0, b0); + else + c0 = _mm256_srlv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_3_0); + + // Peform shifting the same way for input array elements with + // idx%4==1. + __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else if constexpr (std::is_same_v) + c1 = _mm256_srav_epi32(a1, b1); + else + c1 = _mm256_srlv_epi32(a1, b1); + c1 = _mm256_shuffle_epi8(c1, ctl_3_1); + + // Peform shifting the same way for input array elements with + // idx%4==2. + __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3); + __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0); + __m256i c2; + if (left_shift) + c2 = _mm256_sllv_epi32(a2, b2); + else if constexpr (std::is_same_v) + c2 = _mm256_srav_epi32(a2, b2); + else + c2 = _mm256_srlv_epi32(a2, b2); + c2 = _mm256_shuffle_epi8(c2, ctl_3_2); + + // Peform shifting the same way for input array elements with + // idx%4==3. + __m256i a3 = _mm256_and_si256(a, keep_3); + __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0); + __m256i c3; + if (left_shift) + c3 = _mm256_sllv_epi32(a3, b3); + else if constexpr (std::is_same_v) + c3 = _mm256_srav_epi32(a3, b3); + else + c3 = _mm256_srlv_epi32(a3, b3); + c3 = _mm256_and_si256(c3, keep_3); + + // Merge partial results into the final result. + __m256i c01 = _mm256_or_si256(c0, c1); + __m256i c23 = _mm256_or_si256(c2, c3); + __m256i c = _mm256_or_si256(c01, c23); + + return c; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for right arithmetic shifting int64_t, so emulating + // it instead. + + // Clamp the shift values such that shift values < 0 and > 64 are changed to + // 64 which results in -1 for negative input and 0 for non-negative input. + __m256i zero = _mm256_set1_epi64x(0); + __m256i max_shift = _mm256_set1_epi64x(64); + __m256i mask = _mm256_or_si256( + _mm256_cmpgt_epi64(zero, b), _mm256_cmpgt_epi64(b, max_shift)); + __m256i shift = _mm256_blendv_epi8(b, max_shift, mask); + // Shift the number logically to the right, thus filling the most + // significant bits with 0s. Then, replace these bits with the sign + // bit. 
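+  // Worked example (illustrative, not from the original source): for a lane
+  // with a = -4 (0xFFFF...FC) and shift = 1, the logical shift below gives
+  // 0x7FFF...FE, `sign_bits` is all ones for that lane, shifting it left by
+  // 64 - 1 = 63 gives 0x8000...00, and OR-ing the two yields 0xFFFF...FE,
+  // i.e. -2, the expected arithmetic-shift result.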
+ __m256i sign_bits = _mm256_cmpgt_epi64(zero, a); + __m256i sign_shift = _mm256_sub_epi64(max_shift, shift); + __m256i sign_ext = _mm256_sllv_epi64(sign_bits, sign_shift); + __m256i c = _mm256_srlv_epi64(a, shift); + c = _mm256_or_si256(c, sign_ext); + + return c; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm256_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..fb1a17f99695c120a31e4243a6ed5dcb698f45f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h @@ -0,0 +1,298 @@ +#pragma once + +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + 2, + mask_t, + 1, + typename std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + auto int64_mask = vec_mask.template cast(); + auto result = at::vec::VectorizedN(); + if constexpr (std::is_same_v) { + result[0] = _mm256_maskload_pd(ptr, int64_mask[0]); + result[1] = _mm256_maskload_pd( + ptr + at::vec::Vectorized::size(), int64_mask[1]); + } else { + result[0] = _mm256_maskload_epi64( + reinterpret_cast(ptr), int64_mask[0]); + result[1] = _mm256_maskload_epi64( + reinterpret_cast( + ptr + at::vec::Vectorized::size()), + int64_mask[1]); + } + return result; + } +}; + +// TODO: add 
specialization of VecMaskLoad for bfloat16/half and int8/uint8 + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castps_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castpd_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +inline bool VecMask::all_zero() const { + return _mm256_testz_si256(mask_[0], mask_[0]); +} + +template <> +inline bool VecMask::is_masked(int i) const { + return _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])) & (1 << i); +} + +template <> +inline bool VecMask::all_masked() const { + int mask = _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])); + return mask == 0xff; +} + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = all_zero && (_mm256_testz_si256(vec_mask[i], vec_mask[i]) > 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 4) { + return _mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[j])) & + (1 << (i - j * 4)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + 
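+    // Note (comment added for clarity): _mm256_movemask_pd packs the sign
+    // bit of each of the four 64-bit lanes into the low 4 bits of an int,
+    // so a fully set 256-bit mask compares equal to 0x0f in the check below.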
for (int i = 0; i < N; ++i) { + all_masked = all_masked && + (_mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[i])) == 0x0f); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + +#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ + T, N, return_type, method, args_def, args) \ + template <> \ + inline return_type VecMask::method args_def const { \ + return cast().method args; \ + } + +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) + +#undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..e43307c13aa535e9011b7a579989d9b52ae77918 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h @@ -0,0 +1,1397 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
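+// An illustrative kernel shape following the conventions above (a sketch,
+// not part of the original header; `qvec`, `scale`, `zp`, `inv_scale` and
+// the squaring step are placeholder choices):
+//
+//   auto fvecs = qvec.dequantize(scale, zp);       // std::array of Vectorized<float>
+//   for (auto& fv : fvecs) {
+//     fv = fv * fv;                                 // elementwise work in float
+//   }
+//   auto out = Vectorized<c10::quint8>::quantize(fvecs, scale_val, zp_val, inv_scale);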
+ +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +#ifdef _MSC_VER +__declspec(align(64)) struct Vectorizedqi { + protected: + __m256i vals; +#else +struct Vectorizedqi { + protected: + __m256i vals __attribute__((aligned(64))); +#endif + + public: + Vectorizedqi() {} + Vectorizedqi(__m256i v) : vals(v) {} + operator __m256i() const { + return vals; + } +}; + +template +__m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + T min_val, + T max_val); + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i /*first*/, + __m256i /*second*/, + int32_t /*min_val*/, + int32_t /*max_val*/) { + // This function is for linkage only, will not be used + TORCH_CHECK(false, "pack_saturate_and_clamp is not supported"); +} + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + int8_t min_val, + int8_t max_val) { + __m256i packed_and_sat = _mm256_packs_epi16(first, second); + return _mm256_max_epi8( + _mm256_set1_epi8(min_val), + _mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val))); +} + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + uint8_t min_val, + uint8_t max_val) { + __m256i packed_and_sat = _mm256_packus_epi16(first, second); + return _mm256_max_epu8( + _mm256_set1_epi8(min_val), + _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val))); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(at::vec::Vectorized src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 8*8 bits + __m128i input_128 = _mm256_castsi256_si128(src); + // Convert from 8*uint8/int8 to 8*int32 + __m256i input_256_int32; + if constexpr (std::is_same_v) + input_256_int32 = _mm256_cvtepu8_epi32(input_128); + else + input_256_int32 = _mm256_cvtepi8_epi32(input_128); + // Convert from 8*int32 to 8*float + return _mm256_cvtepi32_ps(input_256_int32); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(at::vec::Vectorized src) { + // Convert from float32 to int32 with truncation + __m256i x_values_int32 = _mm256_cvttps_epi32(src); + + // Convert from int32 to int16 using signed saturation + __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32); + + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to uint8/int8 using unsigned saturation + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, xy_packed_v, min_val, max_val); + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); +} + +template +__FORCE_INLINE void QuantizeAvx2( + const float* src, + T* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 8; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m256i min_v = _mm256_set1_epi32(min_val); + const __m256i max_v = _mm256_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); + // clang-format off + static const __m256i 
shuffle_mask_v = _mm256_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + __m256i permute_mask_l8_v = + _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m256 x_vals = _mm256_load_ps(src + i); + __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm256_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. + x_transformed_v = + _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // y + __m256 y_vals = _mm256_load_ps(src + i + VLEN); + __m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm256_min_ps(y_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // z + __m256 z_vals = _mm256_load_ps(src + i + 2 * VLEN); + __m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm256_min_ps(z_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // w + __m256 w_vals = _mm256_load_ps(src + i + 3 * VLEN); + __m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm256_min_ps(w_transformed_v, _mm256_set1_ps(int32_float_max_val)); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v); + + // add zero point + x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point)); + y_rounded_v = _mm256_add_epi32(y_rounded_v, _mm256_set1_epi32(zero_point)); + z_rounded_v = _mm256_add_epi32(z_rounded_v, _mm256_set1_epi32(zero_point)); + w_rounded_v = _mm256_add_epi32(w_rounded_v, _mm256_set1_epi32(zero_point)); + + __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX2 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m256 x_vals = _mm256_load_ps(src + i); + __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val)); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point)); + __m256i x_clipped_v = + _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + i), + _mm256_castsi256_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same 
behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. + transformed = zero_point + std::nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type kSize = Vectorized::size(); + static constexpr size_type size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_zp_premul) const { + __m256 float_vals = _mm256_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m256 float_vals = _mm256_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + Vectorized retval; + auto rhs_data = (__m256)rhs[0]; + at::native::quantize_vec( + scale, + zero_point, + (float*)&rhs_data, + (c10::qint32*)&retval.vals, + size()); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epi32( + _mm256_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm256_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + + __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v); + __m256i rounded = _mm256_cvtps_epi32(scaled); + return _mm256_add_epi32(rounded, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m256i RequantizeAvx2( + const std::array, 4>& inp, + __m256 multiplier, + __m256i zp) { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier); + __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[1]), multiplier); + __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[2]), multiplier); + __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[3]), multiplier); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m256i x_v = _mm256_add_epi32(x_rounded_v, zp); + __m256i y_v = _mm256_add_epi32(y_rounded_v, 
zp); + __m256i z_v = _mm256_add_epi32(z_rounded_v, zp); + __m256i w_v = _mm256_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m256i xy_packed_v = _mm256_packs_epi32(x_v, y_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_v, w_v); + + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 + */ + xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + return xyzw_clamped_v; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; + static constexpr int size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int kIntNumVecs = kSize / Vectorized::size(); + static constexpr int int_num_vecs() { + return kIntNumVecs; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; + using value_type = typename c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy) + C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy") + C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy") +#endif + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + C10_CLANG_DIAGNOSTIC_POP() + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + private: + __m256i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm256_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_neg_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float /*scale*/, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epi8(_mm256_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256i int32_val0 = cvtepi8_epi32(int_val0); + __m256i int32_val1 = cvtepi8_epi32(int_val1); + __m256i int32_val2 = cvtepi8_epi32(int_val2); + __m256i int32_val3 = 
cvtepi8_epi32(int_val3); + + __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0)); + __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1)); + __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2)); + __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3)); + + __m256i int32_b0 = cvtepi8_epi32(int_b0); + __m256i int32_b1 = cvtepi8_epi32(int_b1); + __m256i int32_b2 = cvtepi8_epi32(int_b2); + __m256i int32_b3 = cvtepi8_epi32(int_b3); + + __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0); + __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1); + __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2); + __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3); + + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + return RequantizeAvx2(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; + static constexpr int size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int kIntNumVecs = kSize / Vectorized::size(); + static constexpr int int_num_vecs() { + return kIntNumVecs; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; + using value_type = typename c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi8(uw); + } + + // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy) + C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy") + C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy") +#endif + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + C10_CLANG_DIAGNOSTIC_POP() + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + private: + __m256i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm256_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float /*scale*/, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epu8(_mm256_max_epu8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256i int32_val0 = cvtepu8_epi32(int_val0); + __m256i int32_val1 = cvtepu8_epi32(int_val1); + __m256i int32_val2 = cvtepu8_epi32(int_val2); + __m256i int32_val3 = cvtepu8_epi32(int_val3); 
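+  // Editor's note (illustrative, not part of the upstream header): the
+  // dequantize() overloads above implement the usual affine mapping
+  //   x = scale * (q - zero_point),
+  // either directly or in the fused form fmadd(scale, q, -scale * zero_point).
+  // A hypothetical scalar reference, for comparison only:
+  //
+  //   inline float dequantize_scalar(uint8_t q, float scale, int32_t zp) {
+  //     return scale * (static_cast<float>(q) - static_cast<float>(zp));
+  //   }
+  //
+  // e.g. scale = 0.1f, zp = 3, q = 13 gives 0.1f * (13 - 3) = 1.0f, matching
+  // the fused form 0.1f * 13 + (-0.1f * 3) = 1.3f - 0.3f.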
+ + __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0)); + __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1)); + __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2)); + __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3)); + + __m256i int32_b0 = cvtepu8_epi32(int_b0); + __m256i int32_b1 = cvtepu8_epi32(int_b1); + __m256i int32_b2 = cvtepu8_epi32(int_b2); + __m256i int32_b3 = cvtepu8_epi32(int_b3); + + __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0); + __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1); + __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2); + __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3); + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + return RequantizeAvx2(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#elif !defined(CPU_CAPABILITY_SVE256) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with AVX2. This may not be an issue, because +// currently for quantization we assume the user has at least AVX512 +// installed, so these can simply act as a reference implementation. +// +// If in the future we relax this requirement (AVX2+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size_ / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size_ / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (const auto i : c10::irange(size())) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized /*scale_zp_premul*/) const { + float_vec_return_type rv; + for (const auto i : c10::irange(float_num_vecs())) { + float tmp_vals[Vectorized::size()]; + for (const auto j : c10::irange(Vectorized::size())) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], + zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return dequantize(scale, zero_point, scale_zp_premul); + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + Vectorized::size()> { + using 
VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (const auto i : c10::irange(size())) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = + std::nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the 
output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
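+    // Editor's note (illustrative): this is the standard masked-load pattern --
+    // zero a full-width stack buffer, copy only `count` elements into it, then
+    // load the whole buffer -- so lanes at index >= count hold a well-defined
+    // zero rather than uninitialized memory.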
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // if defined(CPU_CAPABILITY_AVX2) + +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +std::pair, Vectorized> inline convert_int8_to_float( + at::vec::Vectorized src) { + auto s8x8 = vld1_s8(src.operator const int8_t*()); + auto s16x8 = vmovl_s8(s8x8); + + auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); + auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); + + return std::make_pair( + Vectorized(vcvtq_f32_s32(s32x4_lo)), + Vectorized(vcvtq_f32_s32(s32x4_hi))); +} + +std::pair, Vectorized> inline convert_int8_to_float( + at::vec::Vectorized src) { + auto u8x8 = vld1_u8(src.operator const uint8_t*()); + auto u16x8 = vmovl_u8(u8x8); + auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); + auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); + + return std::make_pair( + Vectorized(vcvtq_f32_u32(u32x4_lo)), + Vectorized(vcvtq_f32_u32(u32x4_hi))); +} + +Vectorized inline convert_int8_half_register_to_float( + at::vec::Vectorized src) { + auto s8x8 = vld1_s8(src.operator const int8_t*()); + auto s16x8 
= vmovl_s8(s8x8); + + auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); + + return Vectorized(vcvtq_f32_s32(s32x4_lo)); +} + +Vectorized inline convert_int8_half_register_to_float( + at::vec::Vectorized src) { + auto u8x8 = vld1_u8(src.operator const uint8_t*()); + auto u16x8 = vmovl_u8(u8x8); + auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); + + return Vectorized(vcvtq_f32_u32(u32x4_lo)); +} + +#endif +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..5a00e46ed179c62f17ba7e086f466f44d2353c7d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +} + +inline void load_fp32_from_bf16( + const c10::BFloat16* data, + Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_bf16( + const c10::BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_bf16(data, out1); + data += Vectorized::size(); + load_fp32_from_bf16(data, out2); +} + +inline void load_fp32_from_fp16(const c10::Half* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_fp16( + const c10::Half* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_fp16(data, out1); + data += Vectorized::size(); + load_fp32_from_fp16(data, out2); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..de737919ad41173a43011c2a01d85ae5bf0642f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include + +// Note: header order is important here +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace at { +namespace vec { + +inline namespace CPU_CAPABILITY { + +DEFINE_CLAMP_FUNCS(c10::quint8) +DEFINE_CLAMP_FUNCS(c10::qint8) +DEFINE_CLAMP_FUNCS(c10::qint32) +DEFINE_CLAMP_FUNCS(int16_t) +DEFINE_CLAMP_FUNCS(int32_t) 
+DEFINE_CLAMP_FUNCS(int64_t) +DEFINE_CLAMP_FUNCS(float) +DEFINE_CLAMP_FUNCS(double) + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + vec_madd(a.vec0(), b.vec0(), c.vec0()), + vec_madd(a.vec1(), b.vec1(), c.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + vint32 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint32 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat32 c0 = vec_float(input_vec0); + vfloat32 c1 = vec_float(input_vec1); + vec_vsx_st(c0, offset0, dst_a); + vec_vsx_st(c1, offset16, dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + vint64 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint64 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat64 c0 = vec_double(input_vec0); + vfloat64 c1 = vec_double(input_vec1); + vec_vsx_st(c0, offset0, reinterpret_cast(dst_a)); + vec_vsx_st(c1, offset16, reinterpret_cast(dst_a)); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} +// Generic implementation to fix compiler error +// TO-DO : Add optimized version for ppc64 +inline std::tuple, Vectorized> convert_half_float( + const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ Half arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_half( + const Vectorized& a, + const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ Half arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +}; + +template <> +std::pair, Vectorized> inline 
interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0); + vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3); + vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0); + vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3); + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11}); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0); + vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0); + + vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3); + vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3); + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23}); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + + vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0()); + vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233}); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); + vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2}); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..80bd71dae7c975643b6e1e99016aa91d319f56bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -0,0 +1,679 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +using ComplexDbl = c10::complex; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } 
__attribute__((__may_alias__)); + + public: + using value_type = ComplexDbl; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + using size_type = int; + static constexpr size_type size() { + return 2; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(ComplexDbl val) { + double real_value = val.real(); + double imag_value = val.imag(); + _vec0 = vfloat64{real_value, imag_value}; + _vec1 = vfloat64{real_value, imag_value}; + } + Vectorized(ComplexDbl val1, ComplexDbl val2) { + _vec0 = vfloat64{val1.real(), val1.imag()}; + _vec1 = vfloat64{val2.real(), val2.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return a; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return b; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static Vectorized C10_ALWAYS_INLINE + el_blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = Vectorized( + vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0)); + return { + vec_sel(a._vec0, b._vec0, mask_complex._vecb0), + vec_sel(a._vec1, b._vec1, mask_complex._vecb1)}; + } + + static Vectorized C10_ALWAYS_INLINE elwise_blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + ComplexDbl base = 0., + step_t step = static_cast(1)) { + return Vectorized(base, base + step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, 
reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexDbl& operator[](int idx) const = delete; + ComplexDbl& operator[](int idx) = delete; + + Vectorized map(ComplexDbl (*const f)(ComplexDbl)) const { + __at_align__ ComplexDbl tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized map(ComplexDbl (*const f)(const ComplexDbl&)) const { + __at_align__ ComplexDbl tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized el_swapped() const { + vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2); + vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2); + return {v0, v1}; + } + + Vectorized el_madd( + const Vectorized& multiplier, + const Vectorized& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + Vectorized el_mergeo() const { + vfloat64 v0 = vec_splat(_vec0, 1); + vfloat64 v1 = vec_splat(_vec1, 1); + return {v0, v1}; + } + + Vectorized el_mergee() const { + vfloat64 v0 = vec_splat(_vec0, 0); + vfloat64 v1 = vec_splat(_vec1, 0); + return {v0, v1}; + } + + static Vectorized el_mergee( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergeh(first._vec0, second._vec0), + vec_mergeh(first._vec1, second._vec1)}; + } + + static Vectorized el_mergeo( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergel(first._vec0, second._vec0), + vec_mergel(first._vec1, second._vec1)}; + } + + Vectorized abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a; + } + + Vectorized abs_() const { + auto vi = el_mergeo(); + auto vr = el_mergee(); + return { + Sleef_hypotd2_u05vsx(vr._vec0, vi._vec0), + Sleef_hypotd2_u05vsx(vr._vec1, vi._vec1)}; + } + + Vectorized abs() const { + return abs_() & vd_real_mask; + } + + Vectorized angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_pd(values, 0x05); // b a + // return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + Vectorized ret; + ret._vec0[0] = std::atan2(_vec0[1], _vec0[0]); + ret._vec1[0] = std::atan2(_vec1[1], _vec1[0]); + return ret; + } + + Vectorized angle() const { + return angle_() & vd_real_mask; + } + + Vectorized real_() const { + return *this & vd_real_mask; + } + Vectorized real() const { + return *this & vd_real_mask; + } + Vectorized imag_() const { + return *this & vd_imag_mask; + } + Vectorized imag() const { + return imag_().el_swapped(); + } + + Vectorized conj_() const { + return *this ^ vd_isign_mask; + } + Vectorized conj() const { + return *this ^ vd_isign_mask; + } + + Vectorized log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. 
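+    // Editor's note (illustrative): log2() and log10() below reuse this
+    // natural log via the change-of-base identity
+    //   log2(z) = log(z) / ln(2),   log10(z) = log(z) / ln(10),
+    // applied as an elementwise multiply by the corresponding constant factor.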
+ return map(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(vd_log2e_inv); + } + Vectorized log10() const { + auto ret = log(); + return ret.elwise_mult(vd_log10e_inv); + } + + Vectorized log1p() const { + return map(std::log1p); + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub(val_2, val_2_swapped); + re = Vectorized(vd_one) - re; + auto root = el_blend<0x0A>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(vd_pi_2) - asin(); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized(vd_imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * vd_imag_half; // i/2*ln() + } + Vectorized atanh() const { + return map(std::atanh); + } + + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized neg() const { + auto z = Vectorized(vd_zero); + return z - *this; + } + Vectorized round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vectorized trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + Vectorized sqrt() const { + return map(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ vd_isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + static Vectorized horizontal_add( + Vectorized& first, + Vectorized& second) { + // Operates on individual floats, see _mm_hadd_ps + // {f0+f1, s0+s1, f2+f3, s2+s3, ...} + // i.e. 
it sums the re and im of each value and interleaves first and + // second: {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...} + return el_mergee(first, second) + el_mergeo(first, second); + } + + static Vectorized horizontal_sub( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vectorized inline operator*( + const Vectorized& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ vd_rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.elwise_mult(vi) + ret; +#else + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ vd_isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub(ac_bd, ad_bc); +#endif + return ret; + } + + Vectorized inline operator/( + const Vectorized& b) const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() + // auto fabs_cd = Vectorized{ + // vec_andc(b._vec0, vd_sign_mask), + // vec_andc(b._vec1, vd_sign_mask)}; // |c| |d| + // auto fabs_dc = fabs_cd.el_swapped(); // |d| |c| + // auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|) + // auto a2 = elwise_div(scale); // a/sc b/sc + // auto b2 = b.elwise_div(scale); // c/sc d/sc + // auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2 + // auto dc2 = b2.el_swapped(); // d/sc c/sc + // dc2 = dc2 ^ vd_rsign_mask; // -d/sc c/sc + // auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2 + // auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2 + // auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 + // (c^2+d^2)/sc^2 ret = ret.elwise_div(denom2); return ret; + + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + this->store(tmp1); + b.store(tmp2); + + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return loadu(out); + } + + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + + Vectorized pow(const Vectorized& exp) const { + __at_align__ ComplexDbl x_tmp[size()]; + __at_align__ ComplexDbl y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + Vectorized operator<(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator<=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator>(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator>=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized eq(const 
Vectorized& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & vd_one; + } + Vectorized ne(const Vectorized& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & vd_one; + } + + DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor) + // elementwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple) + DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vectorized::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vectorized::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. 
+ // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + // (a + ib) * (c + id) = (ac - bd) + i(ad + bc) + // Split into real and imaginary parts + auto a_real = a.el_mergee(); // real part of a + auto a_imag = a.el_mergeo(); // imag part of a + auto b_real = b.el_mergee(); // real part of b + auto b_imag = b.el_mergeo(); // imag part of b + + // Compute components + auto ac = a_real.elwise_mult(b_real); // real*real + auto bd = a_imag.elwise_mult(b_imag); // imag*imag + + // Real part: ac - bd + auto real = ac - bd; + + auto ad = a_real.elwise_mult(b_imag); // real*imag + auto bc = a_imag.elwise_mult(b_real); // imag*real + + // Imag = ad + bc + auto imag = ad + bc; + + // Merge real and imaginary parts into vectors + __vector double v0 = vec_mergeh(real.vec0(), imag.vec0()); // [r0, i0] + __vector double v1 = vec_mergeh(real.vec1(), imag.vec1()); // [r1, i1] + + // Create the final result + auto result = Vectorized{v0, v1}; + return result; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() + // Take absolute values of real and imaginary parts of b + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return Vectorized::loadu(out); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..a2ebafd7eb2b05d131d759c35b9042a678815edc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h @@ -0,0 +1,771 @@ + +#pragma once +#include +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +using ComplexFlt = c10::complex; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 
_vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexFlt; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + using size_type = int; + + static constexpr size_type size() { + return 4; + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(ComplexFlt val) { + float real_value = val.real(); + float imag_value = val.imag(); + _vec0 = vfloat32{real_value, imag_value, real_value, imag_value}; + _vec1 = vfloat32{real_value, imag_value, real_value, imag_value}; + } + + Vectorized( + ComplexFlt val1, + ComplexFlt val2, + ComplexFlt val3, + ComplexFlt val4) { + _vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()}; + _vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + const vbool32 mask_2nd = VsxComplexMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static Vectorized C10_ALWAYS_INLINE + el_blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> 
xxyy + auto mask_complex = Vectorized( + vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1)); + return { + vec_sel( + a._vec0, b._vec0, reinterpret_cast(mask_complex._vec0)), + vec_sel( + a._vec1, b._vec1, reinterpret_cast(mask_complex._vec1)), + }; + } + + static Vectorized elwise_blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, reinterpret_cast(mask._vec0)), + vec_sel(a._vec1, b._vec1, reinterpret_cast(mask._vec1)), + }; + } + + template + static Vectorized arange( + ComplexFlt base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + ComplexFlt(2) * step, + base + ComplexFlt(3) * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexFlt& operator[](int idx) const = delete; + ComplexFlt& operator[](int idx) = delete; + + Vectorized map(ComplexFlt (*const f)(ComplexFlt)) const { + __at_align__ ComplexFlt tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized map(ComplexFlt (*const f)(const ComplexFlt&)) const { + __at_align__ ComplexFlt tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + static Vectorized horizontal_add( + Vectorized& first, + Vectorized& second) { + // Operates on individual floats, see _mm_hadd_ps + // {f0+f1, s0+s1, f2+f3, s2+s3, ...} + // i.e. 
it sums the re and im of each value and interleaves first and + // second: {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...} + return el_mergee(first, second) + el_mergeo(first, second); + } + + static Vectorized horizontal_sub_permD8( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // sum + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vectorized abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a.el_mergee(); + } + + Vectorized abs_() const { + auto vi = el_mergeo(); + auto vr = el_mergee(); + return { + Sleef_hypotf4_u05vsx(vr._vec0, vi._vec0), + Sleef_hypotf4_u05vsx(vr._vec1, vi._vec1)}; + } + + Vectorized abs() const { + return abs_() & real_mask; + } + + Vectorized real_() const { + return *this & real_mask; + } + Vectorized real() const { + return *this & real_mask; + } + Vectorized imag_() const { + return *this & imag_mask; + } + Vectorized imag() const { + // we can use swap_mask or sldwi + auto ret = imag_(); + return { + vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)}; + } + + Vectorized conj_() const { + return *this ^ isign_mask; + } + Vectorized conj() const { + return *this ^ isign_mask; + } + + Vectorized log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(log2e_inv); + } + Vectorized log10() const { + auto ret = log(); + return ret.elwise_mult(log10e_inv); + } + + Vectorized log1p() const { + return map(std::log1p); + } + + Vectorized el_swapped() const { + vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask); + vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + Vectorized el_mergee() const { + // as mergee phased in , we can use vec_perm with mask + return {vec_mergee(_vecb0, _vecb0), vec_mergee(_vecb1, _vecb1)}; + } + + Vectorized el_mergeo() const { + // as mergeo phased in , we can use vec_perm with mask + return {vec_mergeo(_vecb0, _vecb0), vec_mergeo(_vecb1, _vecb1)}; + } + + Vectorized el_madd( + const Vectorized& multiplier, + const Vectorized& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + static Vectorized el_mergee( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergee(first._vecb0, second._vecb0), + vec_mergee(first._vecb1, second._vecb1)}; + } + + static Vectorized el_mergeo( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergeo(first._vecb0, second._vecb0), + vec_mergeo(first._vecb1, second._vecb1)}; + } + + Vectorized angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_ps(values, 0xB1); // b a + // return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + Vectorized ret; + for (int i = 0; i < 4; i += 2) { + ret._vec0[i] = std::atan2(_vec0[i + 1], _vec0[i]); + ret._vec1[i] = std::atan2(_vec1[i + 1], _vec1[i]); + } + return ret; + } + + Vectorized angle() const { + return angle_() & real_mask; + } + + Vectorized sin() const { + return map(std::sin); 
+ } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + Vectorized ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized neg() const { + auto z = Vectorized(zero); + return z - *this; + } + Vectorized round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + Vectorized sqrt() const { + return map(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized pow(const Vectorized& exp) const { + __at_align__ ComplexFlt x_tmp[size()]; + __at_align__ ComplexFlt y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized(imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * imag_half; // i/2*ln() + } + Vectorized atanh() const { + return map(std::atanh); + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(pi_2) - asin(); + } + + Vectorized inline operator*( + const Vectorized& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.elwise_mult(vi) + ret; + return ret; + +#else + + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub_permD8(ac_bd, ad_bc); + return ret; +#endif + } + + Vectorized inline operator/( + const Vectorized& b) const { +#if 1 + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + this->store(tmp1); + b.store(tmp2); + + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return loadu(out); +#else + auto fabs_cd = Vectorized{ + vec_andc(b._vec0, sign_mask), vec_andc(b._vec1, sign_mask)}; // |c| |d| + auto fabs_dc = fabs_cd.el_swapped(); // |d| |c| + auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|) + auto a2 = elwise_div(scale); // a/sc b/sc + auto b2 = b.elwise_div(scale); // c/sc d/sc + auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/s + auto dc2 = b2.el_swapped(); // d/sc c/sc + dc2 = dc2 ^ rsign_mask; // -d/sc c/sc + auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2 + auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2 + auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 + ret = ret.elwise_div(denom2); + return ret; +#endif + } + + Vectorized asin() 
const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + +#if 1 + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub_permD8(val_2, val_2_swapped); + re = Vectorized(one) - re; + auto root = el_blend<0xAA>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); +#else + return map(std::asin); +#endif + } + + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + + Vectorized eq(const Vectorized& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & one; + } + Vectorized ne(const Vectorized& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & one; + } + + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + Vectorized operator<(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator<=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator>(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator>=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor) + // elementwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple) + DEFINE_MEMBER_OP(elwise_max, ComplexFlt, vec_max) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vectorized::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vectorized::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. 
+ // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + // (a + ib) * (c + id) = (ac - bd) + i(ad + bc) + // Split into real and imaginary parts + auto a_real = a.el_mergee(); // real part of a + auto a_imag = a.el_mergeo(); // imag part of a + auto b_real = b.el_mergee(); // real part of b + auto b_imag = b.el_mergeo(); // imag part of b + + auto b_imag_neg = b_imag ^ rsign_mask; + // Compute components + auto ac = a_real.elwise_mult(b_real); // real * real + auto bd = a_imag.elwise_mult(b_imag_neg); // imag * imag + auto ad = a_real.elwise_mult(b_imag); // real * imag + auto bc = a_imag.elwise_mult(b_real); // imag * real + + // Real = ac - bd (fix the negative bd part) + auto real = ac + bd; // Real part calculation + auto imag = ad + bc; // Imaginary part calculation + + // Step 1: Extract from real and imag + __vector float r0 = real.vec0(); // {r0, r1, r2, r3} + __vector float i0 = imag.vec0(); // {i0, i1, i2, i3} + + __vector float r1 = real.vec1(); // imag[0..3] + __vector float i1 = imag.vec1(); // imag[4..7] + + __vector unsigned char perm_lo = { + 0, + 1, + 2, + 3, // r0 + 16, + 17, + 18, + 19, // + 8, + 9, + 10, + 11, // r1 + 24, + 25, + 26, + 27}; + __vector float v0 = + vec_perm(r0, i0, perm_lo); // Interleave r0 and i0, r1 and i1 + __vector float v1 = vec_perm(r1, i1, perm_lo); + Vectorized result(v0, v1); + return result; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + // Take absolute values of real and imaginary parts of b + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : + c10::irange(Vectorized>:: + size())) { //{Vectorized>::size())) + //{ + out[i] = tmp1[i] / tmp2[i]; + } + return Vectorized::loadu(out); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..a8474402aa7d96fd45a8e98d1e0397dad0720c8e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -0,0 +1,512 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at { +namespace vec { + 
+inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = double; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(double scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + double scalar1, + double scalar2, + double scalar3, + double scalar4) + : _vec0{vfloat64{scalar1, scalar2}}, _vec1{vfloat64{scalar3, scalar4}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + int zero_mask() const { + auto cmp = (*this == vd_zero); + return (cmp._vecb0[0] & 1) | (cmp._vecb0[1] & 2) | (cmp._vecb1[0] & 4) | + (cmp._vecb1[1] & 8); + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return {(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return {(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return {a._vec0, (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return {b._vec0, (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + 
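The blend machinery above follows one contract: blend<mask>() takes lane i from b when bit i of the compile-time mask is set and from a otherwise, and blendv() applies the same per-lane selection with a runtime mask produced by a lane-wise comparison. A minimal scalar sketch of that contract for the 4-lane double case follows; the helper names are illustrative only and are not part of the ATen header.

    #include <array>
    #include <cstddef>

    // Scalar model of Vectorized<double>::blend<mask>(a, b) on 4 lanes:
    // a set bit i in `mask` selects b[i], a cleared bit keeps a[i].
    template <unsigned mask>
    std::array<double, 4> blend_ref(const std::array<double, 4>& a,
                                    const std::array<double, 4>& b) {
      std::array<double, 4> out{};
      for (std::size_t i = 0; i < out.size(); ++i) {
        out[i] = ((mask >> i) & 1u) ? b[i] : a[i];
      }
      return out;
    }

    // blendv() is the same selection driven by a runtime per-lane flag,
    // which in the header comes from a vec256 comparison result.
    std::array<double, 4> blendv_ref(const std::array<double, 4>& a,
                                     const std::array<double, 4>& b,
                                     const std::array<bool, 4>& pick_b) {
      std::array<double, 4> out{};
      for (std::size_t i = 0; i < out.size(); ++i) {
        out[i] = pick_b[i] ? b[i] : a[i];
      }
      return out;
    }

The fully specialized cases (mask all clear, mask all set, exactly one half selected) appear to exist so that the common whole-register blends never have to materialize a vbool64 select mask at all.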
template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + + static Vectorized C10_ALWAYS_INLINE + set(const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + Vectorized map(double (*const f)(double)) const { + Vectorized ret; + for (const auto i : c10::irange(size() / 2)) { + ret._vec0[i] = f(_vec0[i]); + } + for (const auto i : c10::irange(size() / 2)) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vectorized mapbi( + double (*const f)(double, double), + const Vectorized& other) const { + Vectorized ret; + for (const auto i : c10::irange(size() / 2)) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (const auto i : c10::irange(size() / 2)) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE acos() const { + return {Sleef_acosd2_u10(_vec0), Sleef_acosd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE acosh() const { + return {Sleef_acoshd2_u10(_vec0), Sleef_acoshd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asin() const { + return {Sleef_asind2_u10(_vec0), Sleef_asind2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asinh() const { + return {Sleef_asinhd2_u10(_vec0), Sleef_asinhd2_u10(_vec1)}; + } + Vectorized atan() const { + return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)}; + } + Vectorized atanh() const { + return {Sleef_atanhd2_u10(_vec0), Sleef_atanhd2_u10(_vec1)}; + } + Vectorized atan2(const Vectorized& b) const { + return { + Sleef_atan2d2_u10(_vec0, b._vec0), Sleef_atan2d2_u10(_vec1, b._vec1)}; + } + Vectorized copysign(const Vectorized& sign) const { + return { + Sleef_copysignd2(_vec0, sign._vec0), + Sleef_copysignd2(_vec1, sign._vec1)}; + } + Vectorized erf() const { + return {Sleef_erfd2_u10(_vec0), Sleef_erfd2_u10(_vec1)}; + } + Vectorized erfc() const { + return {Sleef_erfcd2_u15(_vec0), Sleef_erfcd2_u15(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp() const { + return {Sleef_expd2_u10(_vec0), Sleef_expd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp2() const { + return {Sleef_exp2d2_u10(_vec0), Sleef_exp2d2_u10(_vec1)}; + } + Vectorized expm1() const { + return {Sleef_expm1d2_u10(_vec0), Sleef_expm1d2_u10(_vec1)}; + } + Vectorized 
C10_ALWAYS_INLINE exp_u20() const { + return exp(); + } + + Vectorized lgamma() const __ubsan_ignore_undefined__ { + return {Sleef_lgammad2_u10(_vec0), Sleef_lgammad2_u10(_vec1)}; + } + + Vectorized erfinv() const { + return map(calc_erfinv); + } + + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE log() const { + return {Sleef_logd2_u10(_vec0), Sleef_logd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log10() const { + return {Sleef_log10d2_u10(_vec0), Sleef_log10d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log1p() const { + return {Sleef_log1pd2_u10(_vec0), Sleef_log1pd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log2() const { + return {Sleef_log2d2_u10(_vec0), Sleef_log2d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cos() const { + return {Sleef_cosd2_u10(_vec0), Sleef_cosd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshd2_u10(_vec0), Sleef_coshd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sin() const { + return {Sleef_sind2_u10(_vec0), Sleef_sind2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhd2_u10(_vec0), Sleef_sinhd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tan() const { + return {Sleef_tand2_u10(_vec0), Sleef_tand2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhd2_u10(_vec0), Sleef_tanhd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return { + vec_div(vd_one, _vec0), // vec_re(_vec0) is estimated one. 
+ vec_div(vd_one, _vec1)}; + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized C10_ALWAYS_INLINE pow(const Vectorized& b) const { + return {Sleef_powd2_u10(_vec0, b._vec0), Sleef_powd2_u10(_vec1, b._vec1)}; + } + Vectorized C10_ALWAYS_INLINE fmod(const Vectorized& b) const { + return {Sleef_fmodd2(_vec0, b._vec0), Sleef_fmodd2(_vec1, b._vec1)}; + } + + Vectorized hypot(const Vectorized& b) const { + return { + Sleef_hypotd2_u05(_vec0, b._vec0), Sleef_hypotd2_u05(_vec1, b._vec1)}; + } + + Vectorized nextafter(const Vectorized& b) const { + return { + Sleef_nextafterd2(_vec0, b._vec0), Sleef_nextafterd2(_vec1, b._vec1)}; + } + + Vectorized igamma(const Vectorized& x) const { + return mapbi(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapbi(calc_igammac, x); + } + + Vectorized i0() const { + return map(calc_i0); + } + + Vectorized i0e() const { + return map(calc_i0e); + } + + Vectorized digamma() const { + return map(calc_digamma); + } + + Vectorized _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + DEFINE_MEMBER_OP(operator==, double, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, double, vec_cmpne) + DEFINE_MEMBER_OP(operator<, double, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, double, vec_cmple) + DEFINE_MEMBER_OP(operator>, double, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, double, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, double, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, double, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, double, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, double, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, double, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, double, vec_cmpge) + DEFINE_MEMBER_OP(operator+, double, vec_add) + DEFINE_MEMBER_OP(operator-, double, vec_sub) + DEFINE_MEMBER_OP(operator*, double, vec_mul) + DEFINE_MEMBER_OP(operator/, double, vec_div) + DEFINE_MEMBER_OP(maximum, double, vec_max_nan2) + DEFINE_MEMBER_OP(minimum, double, vec_min_nan2) + DEFINE_MEMBER_OP(operator&, double, vec_and) + DEFINE_MEMBER_OP(operator|, double, vec_or) + DEFINE_MEMBER_OP(operator^, double, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, double, vec_madd) +}; +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), 
b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..df5f10a6e5b0128106628db6b33ddcd84e619d71 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -0,0 +1,545 @@ +#pragma once + +#include +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] + +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = float; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + using size_type = int; + + static constexpr size_type size() { + return 8; + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(float scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + float scalar1, + float scalar2, + float scalar3, + float scalar4, + float scalar5, + float scalar6, + float scalar7, + float scalar8) + : _vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, 
mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + + Vectorized map(float (*const f)(float)) const { + Vectorized ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vectorized mapbi( + float (*const f)(float, float), + const Vectorized& other) const { + Vectorized ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + + Vectorized _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } 
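The isnan() just above leans on the IEEE-754 rule that a NaN is the only value that does not compare equal to itself: (x == x) produces an all-zero lane exactly where x is NaN, and the following NOR turns that into an all-ones lane. A standalone scalar sketch of the same trick (illustrative code, not ATen API):

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Per-lane NaN mask without calling std::isnan: mirror (x == x)
    // followed by a NOR, as in the vectorized isnan() above.
    std::uint32_t isnan_mask(float x) {
      return (x == x) ? 0u : 0xFFFFFFFFu;  // all-ones marks a NaN lane
    }

    int main() {
      assert(isnan_mask(1.0f) == 0u);
      assert(isnan_mask(std::nanf("")) == 0xFFFFFFFFu);
      return 0;
    }

Note that this relies on strict IEEE comparison semantics, so it would be defeated by -ffast-math style flags that assume NaNs cannot occur.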
+ + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + //__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + auto cmp = (*this == zero); + // return _mm256_movemask_ps(cmp); + // possible simulation //mask= lvsl ( 0 ) vbpermq( vec, mask <<5) + vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits); + vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits); + return (result0[1] >> 12 | (result1[1] >> 8)); + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE acos() const { + return {Sleef_acosf4_u10(_vec0), Sleef_acosf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE acosh() const { + return {Sleef_acoshf4_u10(_vec0), Sleef_acoshf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asin() const { + return {Sleef_asinf4_u10(_vec0), Sleef_asinf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asinh() const { + return {Sleef_asinhf4_u10(_vec0), Sleef_asinhf4_u10(_vec1)}; + } + Vectorized atan() const { + return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)}; + } + Vectorized atanh() const { + return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)}; + } + Vectorized atan2(const Vectorized& b) const { + return { + Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)}; + } + Vectorized copysign(const Vectorized& sign) const { + return { + Sleef_copysignf4(_vec0, sign._vec0), + Sleef_copysignf4(_vec1, sign._vec1)}; + } + Vectorized lgamma() const { + return {Sleef_lgammaf4_u10(_vec0), Sleef_lgammaf4_u10(_vec1)}; + } + Vectorized erf() const { + return {Sleef_erff4_u10(_vec0), Sleef_erff4_u10(_vec1)}; + } + + Vectorized erfc() const { + return {Sleef_erfcf4_u15(_vec0), Sleef_erfcf4_u15(_vec1)}; + } + + Vectorized erfinv() const { + return map(calc_erfinv); + } + + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE exp() const { + return {Sleef_expf4_u10(_vec0), Sleef_expf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp2() const { + return {Sleef_exp2f4_u10(_vec0), Sleef_exp2f4_u10(_vec1)}; + } + Vectorized expm1() const { + return {Sleef_expm1f4_u10(_vec0), Sleef_expm1f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp_u20() const { + return exp(); + } + + Vectorized C10_ALWAYS_INLINE log() const { + return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log10() const { + return {Sleef_log10f4_u10(_vec0), Sleef_log10f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log1p() const { + return {Sleef_log1pf4_u10(_vec0), Sleef_log1pf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log2() const { + return {Sleef_log2f4_u10(_vec0), Sleef_log2f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cos() const { + return {Sleef_cosf4_u10(_vec0), Sleef_cosf4_u10(_vec1)}; + } + 
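Every Vectorized<float> in this header is a pair of 128-bit VSX registers (_vec0 and _vec1, eight lanes in total), so each transcendental is implemented by issuing a 4-lane Sleef kernel once per half; in Sleef's naming, f4 means four float lanes and the u10/u05 suffixes denote 1.0- and 0.5-ulp accuracy bounds. The snippet below is only a stand-in sketch of that split-and-apply pattern, using std::cos in place of Sleef_cosf4_u10; the types and names are invented for illustration and are not part of ATen.

    #include <array>
    #include <cmath>
    #include <cstddef>

    using Half = std::array<float, 4>;  // stand-in for one 128-bit half

    // Apply a 4-lane primitive to one half, as Sleef_cosf4_u10 would.
    Half cos_half(const Half& v) {
      Half out{};
      for (std::size_t i = 0; i < v.size(); ++i) {
        out[i] = std::cos(v[i]);
      }
      return out;
    }

    // An 8-lane vector is just {vec0, vec1}; an elementwise op runs the
    // 4-lane primitive on each half, mirroring
    // {Sleef_cosf4_u10(_vec0), Sleef_cosf4_u10(_vec1)} in the header above.
    struct Vec8f {
      Half vec0, vec1;
      Vec8f cos() const { return {cos_half(vec0), cos_half(vec1)}; }
    };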
Vectorized C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshf4_u10(_vec0), Sleef_coshf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sin() const { + return {Sleef_sinf4_u10(_vec0), Sleef_sinf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhf4_u10(_vec0), Sleef_sinhf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tan() const { + return {Sleef_tanf4_u10(_vec0), Sleef_tanf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhf4_u10(_vec0), Sleef_tanhf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return Vectorized(one) / (*this); + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized C10_ALWAYS_INLINE pow(const Vectorized& exp) const { + return { + Sleef_powf4_u10(_vec0, exp._vec0), Sleef_powf4_u10(_vec1, exp._vec1)}; + } + + Vectorized fmod(const Vectorized& b) const { + return {Sleef_fmodf4(_vec0, b._vec0), Sleef_fmodf4(_vec1, b._vec1)}; + } + + Vectorized hypot(const Vectorized& b) const { + return { + Sleef_hypotf4_u05(_vec0, b._vec0), Sleef_hypotf4_u05(_vec1, b._vec1)}; + } + + Vectorized nextafter(const Vectorized& b) const { + return { + Sleef_nextafterf4(_vec0, b._vec0), Sleef_nextafterf4(_vec1, b._vec1)}; + } + + Vectorized igamma(const Vectorized& x) const { + return mapbi(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapbi(calc_igammac, x); + } + + Vectorized i0() const { + return map(calc_i0); + } + + Vectorized i0e() const { + return map(calc_i0e); + } + + Vectorized digamma() const { + return map(calc_digamma); + } + + DEFINE_MEMBER_OP(operator==, float, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, float, vec_cmpne) + DEFINE_MEMBER_OP(operator<, float, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, float, vec_cmple) + DEFINE_MEMBER_OP(operator>, float, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, float, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge) + DEFINE_MEMBER_OP(operator+, float, vec_add) + DEFINE_MEMBER_OP(operator-, float, vec_sub) + DEFINE_MEMBER_OP(operator*, float, vec_mul) + DEFINE_MEMBER_OP(operator/, float, vec_div) + DEFINE_MEMBER_OP(maximum, float, vec_max_nan2) + DEFINE_MEMBER_OP(minimum, float, vec_min_nan2) + DEFINE_MEMBER_OP(operator&, float, vec_and) + DEFINE_MEMBER_OP(operator|, float, vec_or) + DEFINE_MEMBER_OP(operator^, float, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const 
Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..b5354e66d00eb9c342fea10c7024bf33df3fe3cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -0,0 +1,435 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint16 _vec0; + vint16 _vec1; + }; + struct { + vbool16 _vecb0; + vbool16 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int16_t; + using vec_internal_type = vint16; + using vec_internal_mask_type = vbool16; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool16 v1, vbool16 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int16_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + + C10_ALWAYS_INLINE Vectorized( + int16_t scalar1, + int16_t scalar2, + int16_t scalar3, + int16_t scalar4, + int16_t scalar5, + int16_t scalar6, + int16_t scalar7, + int16_t scalar8, + int16_t scalar9, + int16_t scalar10, + int16_t scalar11, + int16_t scalar12, + int16_t scalar13, + int16_t scalar14, + int16_t scalar15, + int16_t scalar16) + : _vec0{vint16{ + scalar1, + scalar2, + scalar3, + scalar4, + scalar5, + scalar6, + scalar7, + scalar8}}, + _vec1{vint16{ + scalar9, + scalar10, + scalar11, + scalar12, + scalar13, + scalar14, + scalar15, + scalar16}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE 
+ blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t<(mask & 65535) == 65535, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 255), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + + return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) && + ((mask & 255) != 255)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 
1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return { + (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), + (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized 
C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not) + DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int16_t, vec_add) + DEFINE_MEMBER_OP(operator-, int16_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int16_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /) + DEFINE_MEMBER_OP(maximum, int16_t, vec_max) + DEFINE_MEMBER_OP(minimum, int16_t, vec_min) + DEFINE_MEMBER_OP(operator&, int16_t, vec_and) + DEFINE_MEMBER_OP(operator|, int16_t, vec_or) + DEFINE_MEMBER_OP(operator^, int16_t, vec_xor) +}; + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint16 shift_vec0 = reinterpret_cast(b.vec0()); + vuint16 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint16 shift_vec0 = reinterpret_cast(b.vec0()); + vuint16 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h new file mode 100644 index 
0000000000000000000000000000000000000000..8bb885821c87c90bf57b6f415b9dd933c866991a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -0,0 +1,365 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int32_t; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int32_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + int32_t scalar1, + int32_t scalar2, + int32_t scalar3, + int32_t scalar4, + int32_t scalar5, + int32_t scalar6, + int32_t scalar7, + int32_t scalar8) + : _vec0{vint32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t<(mask & 255) == 255, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 15), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + + return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 
1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) && + ((mask & 15) != 15)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return { + (vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), + (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + int32_t base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; 
+ } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not) + DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int32_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int32_t, vec_add) + DEFINE_MEMBER_OP(operator-, int32_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int32_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /) + DEFINE_MEMBER_OP(maximum, int32_t, vec_max) + DEFINE_MEMBER_OP(minimum, int32_t, vec_min) + DEFINE_MEMBER_OP(operator&, int32_t, vec_and) + DEFINE_MEMBER_OP(operator|, int32_t, vec_or) + DEFINE_MEMBER_OP(operator^, int32_t, vec_xor) +}; + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint32 shift_vec0 = reinterpret_cast(b.vec0()); + vuint32 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint32 shift_vec0 = reinterpret_cast(b.vec0()); + vuint32 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff 
--git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..f04aeaf1038bb013b73bb321641f3b0d23258b61 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -0,0 +1,319 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint64 _vec0; + vint64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int64_t; + using vec_internal_type = vint64; + using vec_internal_mask_type = vbool64; + using size_type = int; + using ElementType = signed long long; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int64_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + int64_t scalar1, + int64_t scalar2, + int64_t scalar3, + int64_t scalar4) + : _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask & 15) == 15, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t<(mask > 0 && mask < 3), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + const vbool64 mask_1st = (vbool64){g0, g1}; + return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1}; + } + + template + static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return {a._vec0, (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 3) && (mask & 3) != 0 && (mask & 15) != 15, + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_1st = 
(vbool64){g0, g1}; + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return { + (vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), + (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + int64_t base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + + static Vectorized C10_ALWAYS_INLINE + set(const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + static_assert(sizeof(double) == sizeof(value_type)); + const double* dptr = reinterpret_cast(ptr); + return {// treat it as double load + (vint64)vec_vsx_ld(offset0, dptr), + (vint64)vec_vsx_ld(offset16, dptr)}; + } + + __at_align__ double tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + (vint64)vec_vsx_ld(offset0, tmp_values), + (vint64)vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + double* dptr = reinterpret_cast(ptr); + vec_vsx_st((vfloat64)_vec0, offset0, dptr); + vec_vsx_st((vfloat64)_vec1, offset16, dptr); + } else if (count > 0) { + __at_align__ double tmp_values[size()]; + vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); + vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not) + DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int64_t, vec_add) + DEFINE_MEMBER_OP(operator-, int64_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int64_t, vec_mul) + DEFINE_MEMBER_OP(operator/, int64_t, vec_div) + DEFINE_MEMBER_OP(maximum, int64_t, vec_max) + DEFINE_MEMBER_OP(minimum, int64_t, vec_min) + DEFINE_MEMBER_OP(operator&, int64_t, vec_and) + 
DEFINE_MEMBER_OP(operator|, int64_t, vec_or) + DEFINE_MEMBER_OP(operator^, int64_t, vec_xor) +}; + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint64 shift_vec0 = reinterpret_cast(b.vec0()); + vuint64 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint64 shift_vec0 = reinterpret_cast(b.vec0()); + vuint64 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..4466dcd2add0883698cae100f0c620a9ce661b2d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -0,0 +1,301 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. 
+// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +struct Vectorized { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + + using size_type = int; + static constexpr size_type size() { + return 8; + } + + static constexpr size_t float_num_vecs() { + return 1; + } + static constexpr int int_num_vecs() { + return 1; + } + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(const c10::qint32& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_sub_zero_point_0 = vec_sub(float_vals0, zero_point_vec0); + vfloat32 vec_sub_zero_point_1 = vec_sub(float_vals1, zero_point_vec1); + Vectorized vf0 = { + vec_mul(scale_vec0, vec_sub_zero_point_0), + vec_mul(scale_vec1, vec_sub_zero_point_1)}; + return {vf0}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return {Vectorized{ + (float_vals0 - zero_point0) * scale_vec0, + (float_vals1 - zero_point1) * scale_vec1}}; + } + + static 
Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized retval; + + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)(zero_point)); + Vectorized vf0 = rhs[0]; + + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + Vectorized relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) const { + vint32 max0 = vec_max(_vec0, zero_point._vec0); + vint32 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {*this - b}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 vec_mult = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + Vectorized vi = inp[0]; + vfloat32 vecf0 = vec_float(vi.vec0()); + vfloat32 vecf1 = vec_float(vi.vec1()); + + vecf0 = vec_mul(vecf0, vec_mult); + vecf1 = vec_mul(vecf1, vec_mult); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + + vint32 veci0 = vec_add(vec_signed(vecf0), vec_zero_point); + vint32 veci1 = vec_add(vec_signed(vecf1), vec_zero_point); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /) + DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) 
{ + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..f32c79a87fbede1c2081dc5ae3f8964e681e19be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h @@ -0,0 +1,512 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
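+// A minimal sketch of that pattern for c10::qint8 (illustrative only: the
+// names `src`, `dst`, `scale`, `zero_point`, `scale_zp_premul`, `scale_f`,
+// `zp` and `inv_scale_f` stand in for real kernel arguments, and the float
+// vectorization header is assumed to be included):
+//
+//   auto qv = Vectorized<c10::qint8>::loadu(src);
+//   auto fvals = qv.dequantize(scale, zero_point, scale_zp_premul);
+//   for (size_t i = 0; i < Vectorized<c10::qint8>::float_num_vecs(); ++i) {
+//     fvals[i] = fvals[i] * fvals[i];  // any elementwise float math here
+//   }
+//   Vectorized<c10::qint8>::quantize(fvals, scale_f, zp, inv_scale_f)
+//       .store(dst);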
+ +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +struct Vectorized { + private: + union { + struct { + vint8 _vec0; + vint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + using size_type = int; + static constexpr size_type size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + using vec_internal_type = vint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vectorized(const c10::qint8& val) + : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {} + + C10_ALWAYS_INLINE Vectorized(const Vectorized& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vectorized(vint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vectorized loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_substract_src_zp0_0 = 
vec_sub(vecf0_0, zero_point_vec0); + vfloat32 vec_substract_src_zp1_0 = vec_sub(vecf1_0, zero_point_vec1); + Vectorized vf0_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_0), + vec_mul(scale_vec1, vec_substract_src_zp1_0)}; + + vfloat32 vec_substract_src_zp0_1 = vec_sub(vecf0_1, zero_point_vec0); + vfloat32 vec_substract_src_zp1_1 = vec_sub(vecf1_1, zero_point_vec1); + Vectorized vf1_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_1), + vec_mul(scale_vec1, vec_substract_src_zp1_1)}; + + vfloat32 vec_substract_src_zp0_2 = vec_sub(vecf0_2, zero_point_vec0); + vfloat32 vec_substract_src_zp1_2 = vec_sub(vecf1_2, zero_point_vec1); + Vectorized vf2_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_2), + vec_mul(scale_vec1, vec_substract_src_zp1_2)}; + + vfloat32 vec_substract_src_zp0_3 = vec_sub(vecf0_3, zero_point_vec0); + vfloat32 vec_substract_src_zp1_3 = vec_sub(vecf1_3, zero_point_vec1); + Vectorized vf3_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_3), + vec_mul(scale_vec1, vec_substract_src_zp1_3)}; + + return {vf0_zp, vf1_zp, vf2_zp, vf3_zp}; + } + + float_vec_return_type C10_ALWAYS_INLINE + dequantize(Vectorized scale, Vectorized zero_point) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return { + Vectorized{ + (vecf0_0 - zero_point0) * scale_vec0, + (vecf1_0 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_1 - zero_point0) * scale_vec0, + (vecf1_1 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_2 - zero_point0) * scale_vec0, + (vecf1_2 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_3 - zero_point0) * scale_vec0, + (vecf1_3 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vint32 vmin = vec_splats(min_val); + // vint32 vmax = vec_splats(max_val); + + Vectorized vf0 = rhs[0]; + Vectorized vf1 = rhs[1]; + Vectorized vf2 = rhs[2]; + Vectorized vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf2 = vec_mul(vecf2, inverse_scale_v); + vecf3 = vec_mul(vecf3, inverse_scale_v); + + vecf4 = vec_mul(vecf4, 
inverse_scale_v); + vecf5 = vec_mul(vecf5, inverse_scale_v); + vecf6 = vec_mul(vecf6, inverse_scale_v); + vecf7 = vec_mul(vecf7, inverse_scale_v); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + // veci0 = vec_min(vmax, vec_max( vmin, vecf0)) ; + // veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ; + // veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ; + // veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ; + + // veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ; + // veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ; + // veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ; + // veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ; + // vec_packs CLAMP already + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vectorized C10_ALWAYS_INLINE + relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized C10_ALWAYS_INLINE + relu6(Vectorized zero_point, Vectorized q_six) const { + vint8 max0 = vec_max(_vec0, zero_point._vec0); + vint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecBshi0 = vec_unpackh(b._vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + vint16 vecBshi1 = vec_unpackl(b._vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecBshi2 = vec_unpackh(b._vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + vint16 vecBshi3 = vec_unpackl(b._vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vectorized(veci0 - vecBi0, veci1 - vecBi1), + Vectorized(veci2 - vecBi2, veci3 - vecBi3), + Vectorized(veci4 - vecBi4, veci5 - vecBi5), + Vectorized(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = 
inp[2]; + Vectorized vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /) + DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return 
Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..05b731ebe705e7c7ca453fe169d608afee796d69 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h @@ -0,0 +1,533 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
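+// Ops that can stay in the quantized domain skip the float round trip
+// entirely; for example a quantized ReLU is just an elementwise max against
+// the broadcast zero point (sketch only; `src`, `dst` and `zp` stand in for
+// real kernel arguments):
+//
+//   auto qv = Vectorized<c10::quint8>::loadu(src);
+//   Vectorized<c10::quint8> qzp(c10::quint8(zp));  // broadcast zero point
+//   qv.relu(qzp).store(dst);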
+ +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +const vint16 mask_unsigned = vec_splats((short int)0xFF); +template <> +struct Vectorized { + private: + union { + struct { + vuint8 _vec0; + vuint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + using size_type = int; + static constexpr size_type size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + using vec_internal_type = vuint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vectorized(const c10::quint8& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + C10_ALWAYS_INLINE Vectorized(const Vectorized& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vectorized(vuint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vectorized loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + 
vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_substract_src_zp0_0 = vec_sub(vecf0_0, zero_point_vec0); + vfloat32 vec_substract_src_zp1_0 = vec_sub(vecf1_0, zero_point_vec1); + Vectorized vf0_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_0), + vec_mul(scale_vec1, vec_substract_src_zp1_0)}; + + vfloat32 vec_substract_src_zp0_1 = vec_sub(vecf0_1, zero_point_vec0); + vfloat32 vec_substract_src_zp1_1 = vec_sub(vecf1_1, zero_point_vec1); + Vectorized vf1_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_1), + vec_mul(scale_vec1, vec_substract_src_zp1_1)}; + + vfloat32 vec_substract_src_zp0_2 = vec_sub(vecf0_2, zero_point_vec0); + vfloat32 vec_substract_src_zp1_2 = vec_sub(vecf1_2, zero_point_vec1); + Vectorized vf2_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_2), + vec_mul(scale_vec1, vec_substract_src_zp1_2)}; + + vfloat32 vec_substract_src_zp0_3 = vec_sub(vecf0_3, zero_point_vec0); + vfloat32 vec_substract_src_zp1_3 = vec_sub(vecf1_3, zero_point_vec1); + Vectorized vf3_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_3), + vec_mul(scale_vec1, vec_substract_src_zp1_3)}; + + return {vf0_zp, vf1_zp, vf2_zp, vf3_zp}; + } + + float_vec_return_type C10_ALWAYS_INLINE + dequantize(Vectorized scale, Vectorized zero_point) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return { + Vectorized{ + (vecf0_0 - zero_point0) * scale_vec0, + (vecf1_0 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_1 - zero_point0) * scale_vec0, + (vecf1_1 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_2 - zero_point0) * scale_vec0, + (vecf1_2 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_3 - zero_point0) * scale_vec0, + (vecf1_3 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 vec_inverse = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vuint32 vmin = vec_splats(min_val); + // 
vuint32 vmax = vec_splats(max_val); + Vectorized vf0 = rhs[0]; + Vectorized vf1 = rhs[1]; + Vectorized vf2 = rhs[2]; + Vectorized vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, vec_inverse); + vecf1 = vec_mul(vecf1, vec_inverse); + vecf2 = vec_mul(vecf2, vec_inverse); + vecf3 = vec_mul(vecf3, vec_inverse); + + vecf4 = vec_mul(vecf4, vec_inverse); + vecf5 = vec_mul(vecf5, vec_inverse); + vecf6 = vec_mul(vecf6, vec_inverse); + vecf7 = vec_mul(vecf7, vec_inverse); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vectorized C10_ALWAYS_INLINE + relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized C10_ALWAYS_INLINE relu6( + Vectorized zero_point, + Vectorized q_six) const { + vuint8 max0 = vec_max(_vec0, zero_point._vec0); + vuint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecBshi0 = vec_unpackh((vint8)b._vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + vint16 vecBshi1 = vec_unpackl((vint8)b._vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecBshi2 = vec_unpackh((vint8)b._vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + vint16 vecBshi3 = vec_unpackl((vint8)b._vec1); + + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecBshi0 = vec_and(vecBshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + vecBshi1 = vec_and(vecBshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecBshi2 = vec_and(vecBshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + vecBshi3 = vec_and(vecBshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = 
vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vectorized(veci0 - vecBi0, veci1 - vecBi1), + Vectorized(veci2 - vecBi2, veci3 - vecBi3), + Vectorized(veci4 - vecBi4, veci5 - vecBi5), + Vectorized(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /) + DEFINE_MEMBER_OP(maximum, c10::quint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const 
Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..0e164d21a38c6102143caff24192ceaa77d034f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h @@ -0,0 +1,526 @@ +#pragma once +#include +#include +#include + +#if defined(__clang__) +typedef __vector __bool char vbool8; +typedef __vector __bool short vbool16; +typedef __vector __bool int vbool32; +typedef __vector __bool long long vbool64; +using vint8 = __attribute__((vector_size(16))) signed char; +using vint16 = __attribute__((vector_size(16))) signed short; +using vint32 = __attribute__((vector_size(16))) signed int; +using vint64 = __attribute__((vector_size(16))) signed long long; +using vuint8 = __attribute__((vector_size(16))) unsigned char; +using vuint16 = __attribute__((vector_size(16))) unsigned short; +using vuint32 = __attribute__((vector_size(16))) unsigned int; +using vuint64 = __attribute__((vector_size(16))) unsigned long long; +using vfloat32 = __attribute__((vector_size(16))) float; +using vfloat64 = __attribute__((vector_size(16))) double; +#else +using vbool8 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; +using vbool16 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; +using vbool32 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) int; +using vbool64 = __attribute__((altivec(vector__))) +__attribute__((altivec(bool__))) long long; +using vint8 = __attribute__((altivec(vector__))) signed char; +using vint16 = __attribute__((altivec(vector__))) signed short; +using vint32 = __attribute__((altivec(vector__))) signed int; +using vint64 = __attribute__((altivec(vector__))) signed long long; +using vuint8 = __attribute__((altivec(vector__))) unsigned char; +using vuint16 = __attribute__((altivec(vector__))) unsigned short; +using vuint32 = __attribute__((altivec(vector__))) unsigned int; +using vuint64 = __attribute__((altivec(vector__))) unsigned long long; +using vfloat32 = __attribute__((altivec(vector__))) float; +using vfloat64 = __attribute__((altivec(vector__))) 
double; +#endif + +#if !defined(vec_float) +C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) { + vfloat32 vec_out; + __asm__("xvcvsxwsp %x0,%x1" : "=wf"(vec_out) : "wa"(vec_in)); + return vec_out; +} +#endif + +#if !defined(vec_signed) +C10_ALWAYS_INLINE vint32 vec_signed(const vfloat32& vec_in) { + vint32 vec_out; + __asm__("xvcvspsxws %x0,%x1" : "=wa"(vec_out) : "wf"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vint64 vec_signed(const vfloat64& vec_in) { + vint64 vec_out; + __asm__("xvcvdpsxds %x0,%x1" : "=wa"(vec_out) : "wd"(vec_in)); + return vec_out; +} +#endif + +#if !defined(vec_neg) +C10_ALWAYS_INLINE vfloat32 vec_neg(const vfloat32& vec_in) { + vfloat32 vec_out; + __asm__("xvnegsp %x0,%x1" : "=wf"(vec_out) : "wf"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vfloat64 vec_neg(const vfloat64& vec_in) { + vfloat64 vec_out; + __asm__("xvnegdp %x0,%x1" : "=wd"(vec_out) : "wd"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vint16 vec_neg(const vint16& vec_in) { + vint16 vint0 = {0, 0, 0, 0, 0, 0, 0, 0}; + return vec_vsubuhm(vint0, vec_in); +} + +C10_ALWAYS_INLINE vint32 vec_neg(const vint32& vec_in) { + vint32 vint0 = {0, 0, 0, 0}; + return vec_vsubuwm(vint0, vec_in); +} + +C10_ALWAYS_INLINE vint64 vec_neg(const vint64& vec_in) { + return -vec_in; +} +#endif + +#if !defined(vec_sldw) +template +C10_ALWAYS_INLINE vfloat32 +vec_sldw_aux(const vfloat32& vec_in0, const vfloat32& vec_in1) { + vfloat32 vec_out; + __asm("xxsldwi %x0, %x1, %x2, %3 " + : "=wa"(vec_out) + : "wa"(vec_in0), "wa"(vec_in1), "I"(C)); + return vec_out; +} + +#define vec_sldw(a, b, c) vec_sldw_aux(a, b) +#endif + +#define vec_not(a) vec_nor(a, a) +#if defined(__clang__) && !defined(vec_splats) +C10_ALWAYS_INLINE vint64 vec_splats(const int64_t& a) { + return vec_splats(a); +} +#endif +// Vectorized min/max which return a if any operand is nan +template +C10_ALWAYS_INLINE T vec_min_nan(const T& a, const T& b) { + return vec_min(a, b); +} +template +C10_ALWAYS_INLINE T vec_max_nan(const T& a, const T& b) { + return vec_max(a, b); +} + +// Specializations for float/double taken from Eigen +template <> +C10_ALWAYS_INLINE vfloat32 +vec_min_nan(const vfloat32& a, const vfloat32& b) { + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE + // regarding NaN + vfloat32 ret; + __asm__("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} +// Specializations for float/double taken from Eigen +template <> +C10_ALWAYS_INLINE vfloat32 +vec_max_nan(const vfloat32& a, const vfloat32& b) { + // NOTE: about 10% slower than vec_max, but consistent with std::min and SSE + // regarding NaN + vfloat32 ret; + __asm__("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} + +template <> +C10_ALWAYS_INLINE vfloat64 +vec_min_nan(const vfloat64& a, const vfloat64& b) { + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE + // regarding NaN + vfloat64 ret; + __asm__("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} +template <> +C10_ALWAYS_INLINE vfloat64 +vec_max_nan(const vfloat64& a, const vfloat64& b) { + // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE + // regarding NaN + vfloat64 ret; + __asm__("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} + +// Vectorizes min/max function which returns nan if any side is nan +#define 
C10_VSX_VEC_NAN_PROPAG(name, type, btype, func) \ + C10_ALWAYS_INLINE type name(const type& a, const type& b) { \ + type tmp = func(a, b); \ + btype nan_a = vec_cmpne(a, a); \ + btype nan_b = vec_cmpne(b, b); \ + tmp = vec_sel(tmp, a, nan_a); \ + return vec_sel(tmp, b, nan_b); \ + } + +C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat32, vbool32, vec_min) +C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat32, vbool32, vec_max) +C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat64, vbool64, vec_min) +C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat64, vbool64, vec_max) + +#undef C10_VSX_VEC_NAN_PROPAG + +#define DEFINE_MEMBER_UNARY_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op() const { \ + return Vectorized{func(_vec0), func(_vec1)}; \ + } + +#define DEFINE_MEMBER_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + return Vectorized{ \ + func(_vec0, other._vec0), func(_vec1, other._vec1)}; \ + } + +#define DEFINE_MEMBER_BITWISE_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + return Vectorized{ \ + func(_vecb0, other._vecb0), func(_vecb1, other._vecb1)}; \ + } + +#define DEFINE_MEMBER_TERNARY_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op( \ + const Vectorized& b, const Vectorized& c) const { \ + return Vectorized{ \ + func(_vec0, b._vec0, c._vec0), func(_vec1, b._vec1, c._vec1)}; \ + } + +#define DEFINE_MEMBER_EMULATE_BINARY_OP(op, op_type, binary_op) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& b) \ + const { \ + Vectorized::vec_internal_type ret_0; \ + Vectorized::vec_internal_type ret_1; \ + for (int i = 0; i < Vectorized::size() / 2; i++) { \ + ret_0[i] = _vec0[i] binary_op b._vec0[i]; \ + ret_1[i] = _vec1[i] binary_op b._vec1[i]; \ + } \ + return Vectorized{ret_0, ret_1}; \ + } + +#define DEFINE_MEMBER_OP_AND_ONE(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + using vvtype = Vectorized::vec_internal_type; \ + const vvtype v_one = vec_splats(static_cast(1.0)); \ + vvtype ret0 = (vvtype)func(_vec0, other._vec0); \ + vvtype ret1 = (vvtype)func(_vec1, other._vec1); \ + return Vectorized{vec_and(ret0, v_one), vec_and(ret1, v_one)}; \ + } + +#define DEFINE_CLAMP_FUNCS(operand_type) \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return Vectorized{ \ + vec_min_nan(vec_max_nan(a.vec0(), min.vec0()), max.vec0()), \ + vec_min_nan(vec_max_nan(a.vec1(), min.vec1()), max.vec1())}; \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_min( \ + const Vectorized& a, \ + const Vectorized& min) { \ + return Vectorized{ \ + vec_max_nan(a.vec0(), min.vec0()), vec_max_nan(a.vec1(), min.vec1())}; \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_max( \ + const Vectorized& a, \ + const Vectorized& max) { \ + return Vectorized{ \ + vec_min_nan(a.vec0(), max.vec0()), vec_min_nan(a.vec1(), max.vec1())}; \ + } + +#define DEFINE_REINTERPRET_CAST_FUNCS( \ + first_type, cast_type, cast_inner_vector_type) \ + template <> \ + C10_ALWAYS_INLINE Vectorized cast( \ + const Vectorized& src) { \ + return Vectorized{ \ + (cast_inner_vector_type)src.vec0(), \ + (cast_inner_vector_type)src.vec1()}; \ + } + +#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(first_type) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, double, vfloat64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, float, vfloat32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int64_t, vint64) \ + 
DEFINE_REINTERPRET_CAST_FUNCS(first_type, int32_t, vint32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int16_t, vint16) + +// it can be used to emulate blend faster +constexpr int blendChoice( + uint32_t mask, + uint32_t half1 = 0xF, + uint32_t half2 = 0xF0) { + uint32_t none = 0; + uint32_t both = half1 | half2; + // clamp it between 0 and both + mask = mask & both; + // return (a._vec0, a._vec1) + if (mask == none) + return 0; + // return (b._vec0,b._vec1) + else if (mask == both) + return 1; + // return (b._vec0,a._vec1) + else if (mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (mask > 0 && mask < half1) + return 4; + // return (*_vec0,b._vec1) + else if ((mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((mask & half1) == 0 && mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((mask & half1) == half1 && mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +constexpr int blendChoiceDbl(uint32_t mask) { + // clamp it 0 and 0xF + return blendChoice(mask, 0x3, 0xC); +} + +constexpr vbool32 VsxMask1(uint32_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (vbool32){g0, g1, g2, g3}; +} + +constexpr vbool32 VsxMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xFF) >> 4; + return VsxMask1(mask2); +} + +constexpr vbool64 VsxDblMask1(uint32_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (vbool64){g0, g1}; +} + +constexpr vbool64 VsxDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +constexpr int maskForComplex(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + if (mask & 4) + complex_mask |= (3 << 4); + if (mask & 8) + complex_mask |= (3 << 6); + return complex_mask; +} + +constexpr int maskForComplexDbl(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + return complex_mask; +} + +constexpr int blendChoiceComplex(uint32_t mask) { + return blendChoice(maskForComplex(mask)); +} + +constexpr int blendChoiceComplexDbl(uint32_t mask) { + return blendChoiceDbl(maskForComplexDbl(mask)); +} + +constexpr vbool32 VsxComplexMask1(uint32_t mask) { + return VsxMask1(maskForComplex(mask)); +} + +constexpr vbool32 VsxComplexMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxMask1(maskForComplex(mask2)); +} + +constexpr vbool64 VsxComplexDblMask1(uint32_t mask) { + return VsxDblMask1(mask); +} + +constexpr vbool64 VsxComplexDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +// constants +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +// +constexpr int offset0 = 0; +constexpr int offset16 = 16; + +// #Constants +const vuint8 mask_zero_bits = vuint8{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 96, + 64, + 32, + 0}; + +const vuint8 swap_mask = + vuint8{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; + +const vint32 v0x7f = vec_splats(0x7f); +const vint32 vi_0 = vec_splats((int)(0)); +const 
vint32 vi_1 = vec_splats((int)1); +const vint32 vi_2 = vec_splats((int)2); +const vint32 vi_4 = vec_splats((int)4); +const vint32 vi_inv1 = vec_splats((int)~1); +const vuint32 vu_29 = vec_splats(29u); +const vuint32 vu_23 = vec_splats(23u); + +const vbool32 inv_mant_mask = (vbool32)vec_splats((unsigned int)~0xff800000); +const vbool32 sign_mask = (vbool32)vec_splats((int)0x80000000); +const vbool32 real_mask = vbool32{0xFFFFFFFF, 0x0, 0xFFFFFFFF, 0x0}; +const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF}; +const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000}; +const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0}; + +const vbool64 vd_sign_mask = vbool64{0x8000000000000000, 0x8000000000000000}; +const vbool64 vd_imag_mask = vbool64{0x0, 0xFFFFFFFFFFFFFFFF}; +const vbool64 vd_real_mask = vbool64{0xFFFFFFFFFFFFFFFF, 0x0}; +const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000}; +const vbool64 vd_rsign_mask = vbool64{0x8000000000000000, 0x0}; + +const vfloat32 zero = vec_splats(0.f); +const vfloat32 half = vec_splats(0.5f); +const vfloat32 one = vec_splats(1.f); +const vfloat32 two = vec_splats(2.0f); +const vfloat32 _4div_pi = vec_splats(1.27323954473516f); +const vfloat32 v_inf = (vfloat32)vec_splats(0x7f800000u); +const vfloat32 v_minus_inf = + vfloat32{0xff800000u, 0xff800000u, 0xff800000u, 0xff800000u}; +const vfloat32 v_nan = (vfloat32)vec_splats(0x7fffffff); +const vfloat32 log10e_inv = vec_splats(0.43429448190325176f); +const vfloat32 log2e_inv = vec_splats(1.4426950408889634f); +const vfloat32 log2eB_inv = vec_splats(1.442695036924675f); +const vfloat32 cephes_SQRTHF = vec_splats(0.707106781186547524f); +const vfloat32 coscof_p0 = vec_splats(2.443315711809948E-005f); +const vfloat32 coscof_p1 = vec_splats(-1.388731625493765E-003f); +const vfloat32 coscof_p2 = vec_splats(4.166664568298827E-002f); +const vfloat32 exp_hi = vec_splats(104.f); +const vfloat32 exp_lo = vec_splats(-104.f); +const vfloat32 exp_p0 = vec_splats(0.000198527617612853646278381f); +const vfloat32 exp_p1 = vec_splats((0.00139304355252534151077271f)); +const vfloat32 exp_p2 = vec_splats(0.00833336077630519866943359f); +const vfloat32 exp_p3 = vec_splats(0.0416664853692054748535156f); +const vfloat32 exp_p4 = vec_splats(0.166666671633720397949219f); +const vfloat32 exp_p5 = vec_splats(0.5f); +const vfloat32 log_p0 = vec_splats(7.0376836292E-2f); +const vfloat32 log_p1 = vec_splats(-1.1514610310E-1f); +const vfloat32 log_p2 = vec_splats(1.1676998740E-1f); +const vfloat32 log_p3 = vec_splats(-1.2420140846E-1f); +const vfloat32 log_p4 = vec_splats(+1.4249322787E-1f); +const vfloat32 log_p5 = vec_splats(-1.6668057665E-1f); +const vfloat32 log_p6 = vec_splats(+2.0000714765E-1f); +const vfloat32 log_p7 = vec_splats(-2.4999993993E-1f); +const vfloat32 log_p8 = vec_splats(+3.3333331174E-1f); +const vfloat32 log_q1 = vec_splats(-2.12194440e-4f); +const vfloat32 log_q2 = vec_splats(0.693359375f); +const vfloat32 max_logf = vec_splats(88.02969187150841f); +const vfloat32 max_numf = + vec_splats(1.7014117331926442990585209174225846272e38f); +const vfloat32 min_inf = (vfloat32)vec_splats(0xff800000u); +const vfloat32 min_norm_pos = (vfloat32)vec_splats(0x0800000u); +const vfloat32 minus_cephes_dp1 = vec_splats(-0.78515625f); +const vfloat32 minus_cephes_dp2 = vec_splats(-2.4187564849853515625e-4f); +const vfloat32 minus_cephes_dp3 = vec_splats(-3.77489497744594108e-8f); +const vfloat32 negln2f_hi = vec_splats(-0.693145751953125f); +const vfloat32 negln2f_lo = 
vec_splats(-1.428606765330187045e-06f); +const vfloat32 p0 = vec_splats(2.03721912945E-4f); +const vfloat32 p1 = vec_splats(8.33028376239E-3f); +const vfloat32 p2 = vec_splats(1.66667160211E-1f); +const vfloat32 sincof_p0 = vec_splats(-1.9515295891E-4f); +const vfloat32 sincof_p1 = vec_splats(8.3321608736E-3f); +const vfloat32 sincof_p2 = vec_splats(-1.6666654611E-1f); +const vfloat32 tanh_0p625 = vec_splats(0.625f); +const vfloat32 tanh_half_max = vec_splats(44.014845935754205f); +const vfloat32 tanh_p0 = vec_splats(-5.70498872745E-3f); +const vfloat32 tanh_p1 = vec_splats(2.06390887954E-2f); +const vfloat32 tanh_p2 = vec_splats(-5.37397155531E-2f); +const vfloat32 tanh_p3 = vec_splats(1.33314422036E-1f); +const vfloat32 tanh_p4 = vec_splats(-3.33332819422E-1f); +const vfloat32 vcheck = vec_splats((float)(1LL << 24)); +const vfloat32 imag_one = vfloat32{0.f, 1.f, 0.f, 1.f}; +const vfloat32 imag_half = vfloat32{0.f, 0.5f, 0.f, 0.5f}; +const vfloat32 sqrt2_2 = vfloat32{ + 0.70710676908493042f, + 0.70710676908493042, + 0.70710676908493042, + 0.70710676908493042}; +const vfloat32 pi_2 = vfloat32{M_PI / 2, 0.0, M_PI / 2, 0.0}; +const vfloat32 vf_89 = vfloat32{89.f, 89.f, 89.f, 89.f}; +const vfloat64 vd_one = vec_splats(1.0); +const vfloat64 vd_zero = vec_splats(0.0); +const vfloat64 vd_log10e_inv = vec_splats(0.43429448190325176); +const vfloat64 vd_log2e_inv = vec_splats(1.4426950408889634); +const vfloat64 vd_imag_one = vfloat64{0.0, 1.0}; +const vfloat64 vd_imag_half = vfloat64{0.0, 0.5}; +const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757}; +const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0}; + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h new file mode 100644 index 0000000000000000000000000000000000000000..83f9fe36cd4dee1ebd53990ea77f832579af1cc5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -0,0 +1,2970 @@ +#include +#include +#include +#include +#include +#if defined(__clang__) +#include +#elif defined(__GNUC__) || defined(__GNUG__) +#include +#include +#endif +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template +constexpr bool is_zarch_implemented() { + return ( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v); +} + +template +constexpr bool is_zarch_implemented_quant() { + return ( + std::is_same_v || std::is_same_v || + std::is_same_v); +} + +template +constexpr bool is_zarch_implemented_complex() { + return std::is_same_v> || + std::is_same_v>; +} + +constexpr int offset0 = 0; +constexpr int offset16 = 16; + +template +struct VecBinaryType { + using type __attribute__((vector_size(16))) = uintmax_t; +}; + +template <> +struct VecBinaryType<8> { + using type = __attribute__((vector_size(16))) unsigned long long; +}; + +template <> +struct VecBinaryType<4> { + using type = __attribute__((vector_size(16))) unsigned int; +}; + +template <> +struct VecBinaryType<2> { + using type = __attribute__((vector_size(16))) unsigned short; +}; + +template <> +struct VecBinaryType<1> { + using type = __attribute__((vector_size(16))) unsigned char; +}; + +template +struct VecInnerType { + using Type 
__attribute__((vector_size(16))) = T; + using BinaryType = typename VecBinaryType::type; + using ElementType = T; + static constexpr int size = 16 / sizeof(T); +}; + +// define for int64_t properly for load +template <> +struct VecInnerType { + using Type = __attribute__((vector_size(16))) signed long long; + using ElementType = signed long long; + using BinaryType = typename VecBinaryType::type; + static constexpr int size = 16 / sizeof(signed long long); +}; + +template +using ZSimdVect = typename VecInnerType::Type; +template +using ZSimdVectBinary = typename VecInnerType::BinaryType; +template +using ZSimdVectElement = typename VecInnerType::ElementType; + +constexpr int blendChoiceInner( + const uint64_t mask, + const uint64_t half1 = 0xF, + const uint64_t half2 = 0xF0) { + uint64_t none = 0; + uint64_t both = half1 | half2; + // clamp it between 0 and both + auto res_mask = mask & both; + // return (a._vec0, a._vec1) + if (res_mask == none) + return 0; + // return (b._vec0,b._vec1) + else if (res_mask == both) + return 1; + // return (b._vec0, a._vec1) + else if (res_mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (res_mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (res_mask > 0 && res_mask < half1) + return 4; + // return (*_vec0,b._vec1) + else if ((res_mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((res_mask & half1) == 0 && res_mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((res_mask & half1) == half1 && res_mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +template +constexpr int blendChoice(const uint64_t mask) { + static_assert(Z < 1 || Z > 8, "not implemented"); + return blendChoiceInner(mask); +} + +template <> +constexpr int blendChoice<1>(const uint64_t mask) { + return blendChoiceInner(mask, 0x0000FFFF, 0xFFFF0000); +} + +template <> +constexpr int blendChoice<2>(const uint64_t mask) { + return blendChoiceInner(mask, 0x00FF, 0xFF00); +} + +template <> +constexpr int blendChoice<4>(const uint64_t mask) { + return blendChoiceInner(mask, 0xF, 0xF0); +} + +template <> +constexpr int blendChoice<8>(const uint64_t mask) { + // clamp it 0 and 0xF + return blendChoiceInner(mask, 0x3, 0xC); +} + +template +constexpr auto GetMask1(const uint64_t mask) { + return typename VecBinaryType::type{}; +} + +template +constexpr auto GetMask2(const uint64_t mask) { + return typename VecBinaryType::type{}; +} + +template <> +constexpr auto GetMask1<1>(const uint64_t mask) { + constexpr uint8_t t = (int)0xFF; + uint8_t g0 = (mask & 1) * t; + uint8_t g1 = ((mask & 2) >> 1) * t; + uint8_t g2 = ((mask & 4) >> 2) * t; + uint8_t g3 = ((mask & 8) >> 3) * t; + uint8_t g4 = ((mask & 16) >> 4) * t; + uint8_t g5 = ((mask & 32) >> 5) * t; + uint8_t g6 = ((mask & 64) >> 6) * t; + uint8_t g7 = ((mask & 128) >> 7) * t; + uint8_t g8 = ((mask & 256) >> 8) * t; + uint8_t g9 = ((mask & 512) >> 9) * t; + uint8_t g10 = ((mask & 1024) >> 10) * t; + uint8_t g11 = ((mask & 2048) >> 11) * t; + uint8_t g12 = ((mask & 4096) >> 12) * t; + uint8_t g13 = ((mask & 8192) >> 13) * t; + uint8_t g14 = ((mask & 16384) >> 14) * t; + uint8_t g15 = ((mask & 32768) >> 15) * t; + return (typename VecBinaryType<1>::type){ + g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15}; +} + +template <> +constexpr auto GetMask2<1>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFFFFFFFF) >> 16; + return GetMask1<1>(mask2); +} + +template <> +constexpr auto GetMask1<2>(const 
uint64_t mask) { + constexpr uint16_t t = (int)0xFFFF; + uint16_t g0 = (mask & 1) * t; + uint16_t g1 = ((mask & 2) >> 1) * t; + uint16_t g2 = ((mask & 4) >> 2) * t; + uint16_t g3 = ((mask & 8) >> 3) * t; + uint16_t g4 = ((mask & 16) >> 4) * t; + uint16_t g5 = ((mask & 32) >> 5) * t; + uint16_t g6 = ((mask & 64) >> 6) * t; + uint16_t g7 = ((mask & 128) >> 7) * t; + return (typename VecBinaryType<2>::type){g0, g1, g2, g3, g4, g5, g6, g7}; +} + +template <> +constexpr auto GetMask2<2>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFFFF) >> 8; + return GetMask1<2>(mask2); +} + +template <> +constexpr auto GetMask1<4>(const uint64_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (typename VecBinaryType<4>::type){g0, g1, g2, g3}; +} + +template <> +constexpr auto GetMask2<4>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFF) >> 4; + return GetMask1<4>(mask2); +} + +template <> +constexpr auto GetMask1<8>(const uint64_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (typename VecBinaryType<8>::type){g0, g1}; +} + +template <> +constexpr auto GetMask2<8>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xF) >> 2; + return GetMask1<8>(mask2); +} + +template +constexpr int maskForComplex(uint32_t mask) { + return 0; +} + +template <> +constexpr int maskForComplex<8>(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + if (mask & 4) + complex_mask |= (3 << 4); + if (mask & 8) + complex_mask |= (3 << 6); + return complex_mask; +} + +template <> +constexpr int maskForComplex<16>(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + return complex_mask; +} + +template > +constexpr int blend_choice() { + return 0xAA; +} + +template <> +constexpr int blend_choice>() { + return 0x0A; +} + +constexpr int64_t allbitset(int16_t x) { + int64_t onex = 1; + return (onex << x) - onex; +} + +namespace { /* unnamed namespace */ + +ZSimdVect vec_mergee(ZSimdVect x, ZSimdVect y) { + constexpr ZSimdVectBinary mergee_mask{ + 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27}; + return vec_perm(x, y, mergee_mask); +} + +ZSimdVect vec_mergee(ZSimdVect x, ZSimdVect y) { + return vec_mergeh(x, y); +} + +ZSimdVect vec_mergeo(ZSimdVect x, ZSimdVect y) { + constexpr ZSimdVectBinary mergeo_mask{ + 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31}; + return vec_perm(x, y, mergeo_mask); +} + +ZSimdVect vec_mergeo(ZSimdVect x, ZSimdVect y) { + return vec_mergel(x, y); +} + +} /* unnamed namespace */ + +// +template +constexpr auto GetBpermZeroMask() { + return ZSimdVectBinary{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 96, + 64, + 32, + 0}; +} + +template <> +constexpr auto GetBpermZeroMask() { + return ZSimdVectBinary{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 64, + 0}; +} + +constexpr auto GetSwapMaskFloat() { + return ZSimdVectBinary{ + 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; +} + +template +struct is_vec_specialized_for()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using value_type = T; + using vtype = ZSimdVect; + using vmaskType = 
ZSimdVectBinary; + using size_type = int; + // because of gcc inconsistency for int64_t we are obliged to use this, not + // value_type + using ElementType = ZSimdVectElement; + using vinner_data = std::pair; + + private: + vtype _vec0; + vtype _vec1; + + public: + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(ElementType); + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vtype v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(const vinner_data& v) + : _vec0{v.first}, _vec1{v.second} {} + C10_ALWAYS_INLINE Vectorized(vtype v1, vtype v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(T s) + : _vec0{vec_splats((ElementType)s)}, _vec1{vec_splats((ElementType)s)} {} + + template + struct LoaduHelper { + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + __at_align__ ElementType tmp_values[size()] = {}; + std::memcpy( + tmp_values, ptr, std::min(count, size()) * sizeof(ElementType)); + + return { + vec_xl(offset0, &(tmp_values[0])), + vec_xl(offset16, &(tmp_values[0]))}; + } + }; + + template + struct LoaduHelper { + static Vectorized C10_ALWAYS_INLINE + loadu(const ElementType* ptr, int count = size()) { + if (count == size()) { + return {vec_xl(offset0, ptr), vec_xl(offset16, ptr)}; + } + + __at_align__ ElementType tmp_values[size()] = {}; + std::memcpy( + tmp_values, ptr, std::min(count, size()) * sizeof(ElementType)); + + return { + vec_xl(offset0, &(tmp_values[0])), + vec_xl(offset16, &(tmp_values[0]))}; + } + }; + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return LoaduHelper::loadu(ptr, count); + } + + template + static Vectorized C10_ALWAYS_INLINE loadu_one_fourth(const U* ptr) { + // load only first 8 bytes + // only intended to be used with uint8_t + return loadu(ptr, 8 / sizeof(ElementType)); + } + + template + struct StoreHelper { + static void C10_ALWAYS_INLINE + store(const Vectorized& vec, U* ptr, int count = size()) { + if (count > 0) { + __at_align__ ElementType tmp_values[size()]; + vec_xst(vec._vec0, offset0, &(tmp_values[0])); + vec_xst(vec._vec1, offset16, &(tmp_values[0])); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(ElementType)); + } + } + }; + + template + struct StoreHelper { + static void C10_ALWAYS_INLINE + store(const Vectorized& vec, ElementType* ptr, int count = size()) { + if (count == size()) { + vec_xst(vec._vec0, offset0, ptr); + vec_xst(vec._vec1, offset16, ptr); + } else if (count > 0) { + __at_align__ ElementType tmp_values[size()]; + vec_xst(vec._vec0, offset0, &(tmp_values[0])); + vec_xst(vec._vec1, offset16, &(tmp_values[0])); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(ElementType)); + } + } + }; + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + return StoreHelper::store(*this, ptr, count); + } + + C10_ALWAYS_INLINE const vtype& vec0() const { + return _vec0; + } + + C10_ALWAYS_INLINE const vtype& vec1() const { + return _vec1; + } + + C10_ALWAYS_INLINE vinner_data data() const { + return std::make_pair<>(_vec0, _vec1); + } + + C10_ALWAYS_INLINE operator vinner_data() const { + return data(); + } + + C10_ALWAYS_INLINE const vmaskType vecb0() const { + return (vmaskType)_vec0; + } + C10_ALWAYS_INLINE const vmaskType vecb1() const { + return (vmaskType)_vec1; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, mask.vecb0()), + vec_sel(a._vec1, 
b._vec1, mask.vecb1())}; + } + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4) + : _vec0{s1, s2}, _vec1{s3, s4} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4, T s5, T s6, T s7, T s8) + : _vec0{s1, s2, s3, s4}, _vec1{s5, s6, s7, s8} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized( + T s1, + T s2, + T s3, + T s4, + T s5, + T s6, + T s7, + T s8, + T s9, + T s10, + T s11, + T s12, + T s13, + T s14, + T s15, + T s16) + : _vec0{s1, s2, s3, s4, s5, s6, s7, s8}, + _vec1{s9, s10, s11, s12, s13, s14, s15, s16} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized( + T s1, + T s2, + T s3, + T s4, + T s5, + T s6, + T s7, + T s8, + T s9, + T s10, + T s11, + T s12, + T s13, + T s14, + T s15, + T s16, + T s17, + T s18, + T s19, + T s20, + T s21, + T s22, + T s23, + T s24, + T s25, + T s26, + T s27, + T s28, + T s29, + T s30, + T s31, + T s32) + : _vec0{s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16}, + _vec1{ + s17, + s18, + s19, + s20, + s21, + s22, + s23, + s24, + s25, + s26, + s27, + s28, + s29, + s30, + s31, + s32} {} + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized(base, base + step, base + 2 * step, base + 3 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + + // blend section + template + static std::enable_if_t(mask) == 0, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t(mask) == 1, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t(mask) == 2, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t(mask) == 3, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t(mask) == 4, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + return {(vtype)vec_sel(a._vec0, 
b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t(mask) == 5, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t(mask) == 6, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_2nd = GetMask2(mask); + // generated masks + return {a._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t(mask) == 7, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_2nd = GetMask2(mask); + // generated masks + return {b._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t(mask) == 8, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + const vmaskType mask_2nd = GetMask2(mask); + return { + (vtype)vec_sel(a._vec0, b._vec0, mask_1st), + (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static inline std::enable_if_t<(Z >= C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + return b; + } + + template + static inline std::enable_if_t<(Z < C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + if (count == Z) + return blend(a, b); + else + return set_inner(a, b, count); + } + + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + if (count == 0) + return a; + return set_inner<1, size()>(a, b, count); + } + + const ElementType& operator[](int idx) const = delete; + ElementType& operator[](int idx) = delete; + + Vectorized _not() const { + return {(vtype)vec_nor(vecb0(), vecb0()), (vtype)vec_nor(vecb1(), vecb1())}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + return (*this == other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + return (*this != other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const { + return (*this > other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const { + return (*this >= other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const { + return (*this < other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const { + return (*this <= other) & Vectorized((T)1.0); + } + + template , int> = 0> + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + template , int> = 0> + Vectorized C10_ALWAYS_INLINE abs() const { + return {_vec0, _vec1}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {-_vec0, -_vec1}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._not(); + } + + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + + template < + 
typename U = T, + std::enable_if_t, int> = 0> + Vectorized angle() const { + return blendv( + Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0)); + } + + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + int zero_mask() const { + auto cmp = (*this == Vectorized(0)); + constexpr auto mask_zero_bits = GetBpermZeroMask(); + ZSimdVectBinary result0 = + vec_bperm_u128((ZSimdVectBinary)cmp.vecb0(), mask_zero_bits); + ZSimdVectBinary result1 = + vec_bperm_u128((ZSimdVectBinary)cmp.vecb1(), mask_zero_bits); + return (result0[0] | (result1[0] << (size() / 2))); + } + + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE rint() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return Vectorized((T)1) / (*this); + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + template , int> = 0> + inline Vectorized mapOrdinary(float (*const f)(float)) const { + float a00 = f(_vec0[0]); + float a01 = f(_vec0[1]); + float a02 = f(_vec0[2]); + float a03 = f(_vec0[3]); + float a10 = f(_vec1[0]); + float a11 = f(_vec1[1]); + float a12 = f(_vec1[2]); + float a13 = f(_vec1[3]); + return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapOrdinary(double (*const f)(double)) const { + return Vectorized(f(_vec0[0]), f(_vec0[1]), f(_vec1[0]), f(_vec1[1])); + } + + template , int> = 0> + inline Vectorized mapOrdinary( + float (*const f)(float, float), + const Vectorized& b) const { + float a00 = f(_vec0[0], b._vec0[0]); + float a01 = f(_vec0[1], b._vec0[1]); + float a02 = f(_vec0[2], b._vec0[2]); + float a03 = f(_vec0[3], b._vec0[3]); + float a10 = f(_vec1[0], b._vec1[0]); + float a11 = f(_vec1[1], b._vec1[1]); + float a12 = f(_vec1[2], b._vec1[2]); + float a13 = f(_vec1[3], b._vec1[3]); + return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapOrdinary( + double (*const f)(double, double), + const Vectorized& b) const { + return Vectorized( + f(_vec0[0], b._vec0[0]), + f(_vec0[1], b._vec0[1]), + f(_vec1[0], b._vec1[0]), + f(_vec1[1], b._vec1[1])); + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d) const { + vtype a0 = f(_vec0); + vtype a1 = f(_vec1); + return Vectorized{a0, a1}; + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d) const { + return Vectorized(d(_vec0), d(_vec1)); + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, 
DoubleOp d, const Vectorized& b) + const { + vtype a0 = f(_vec0, b._vec0); + vtype a1 = f(_vec1, b._vec1); + return Vectorized{a0, a1}; + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d, const Vectorized& b) + const { + return Vectorized(d(_vec0, b._vec0), d(_vec1, b._vec1)); + } + + Vectorized acos() const { + return mapSleef(Sleef_acosf4_u10, Sleef_acosd2_u10); + } + Vectorized asin() const { + return mapSleef(Sleef_asinf4_u10, Sleef_asind2_u10); + } + Vectorized atan() const { + return mapSleef(Sleef_atanf4_u10, Sleef_atand2_u10); + } + Vectorized atanh() const { + return mapSleef(Sleef_atanhf4_u10, Sleef_atanhd2_u10); + } + + Vectorized erf() const { + return mapSleef(Sleef_erff4_u10, Sleef_erfd2_u10); + } + Vectorized erfc() const { + return mapSleef(Sleef_erfcf4_u15, Sleef_erfcd2_u15); + } + + Vectorized exp() const { + return mapSleef(Sleef_expf4_u10, Sleef_expd2_u10); + } + Vectorized exp2() const { + return mapSleef(Sleef_exp2f4_u10, Sleef_exp2d2_u10); + } + Vectorized expm1() const { + return mapSleef(Sleef_expm1f4_u10, Sleef_expm1d2_u10); + } + Vectorized exp_u20() const { + return exp(); + } + + Vectorized log() const { + return mapSleef(Sleef_logf4_u10, Sleef_logd2_u10); + } + Vectorized log2() const { + return mapSleef(Sleef_log2f4_u10, Sleef_log2d2_u10); + } + Vectorized log10() const { + return mapSleef(Sleef_log10f4_u10, Sleef_log10d2_u10); + } + Vectorized log1p() const { + return mapSleef(Sleef_log1pf4_u10, Sleef_log1pd2_u10); + } + + Vectorized sin() const { + return mapSleef(Sleef_sinf4_u10, Sleef_sind2_u10); + } + Vectorized sinh() const { + return mapSleef(Sleef_sinhf4_u10, Sleef_sinhd2_u10); + } + Vectorized cos() const { + return mapSleef(Sleef_cosf4_u10, Sleef_cosd2_u10); + } + Vectorized cosh() const { + return mapSleef(Sleef_coshf4_u10, Sleef_coshd2_u10); + } + + Vectorized tan() const { + return mapSleef(Sleef_tanf4_u10, Sleef_tand2_u10); + } + Vectorized tanh() const { + return mapSleef(Sleef_tanhf4_u10, Sleef_tanhd2_u10); + } + + Vectorized lgamma() const { + return mapSleef(Sleef_lgammaf4_u10, Sleef_lgammad2_u10); + } + + Vectorized atan2(const Vectorized& b) const { + return mapSleef(Sleef_atan2f4_u10, Sleef_atan2d2_u10, b); + } + Vectorized copysign(const Vectorized& sign) const { + return mapSleef(Sleef_copysignf4, Sleef_copysignd2, sign); + } + Vectorized fmod(const Vectorized& q) const { + return mapSleef(Sleef_fmodf4, Sleef_fmodd2, q); + } + + Vectorized hypot(const Vectorized& b) const { + return mapSleef(Sleef_hypotf4_u05, Sleef_hypotd2_u05, b); + } + + Vectorized pow(const Vectorized& b) const { + return mapSleef(Sleef_powf4_u10, Sleef_powd2_u10, b); + } + + Vectorized nextafter(const Vectorized& b) const { + return mapSleef(Sleef_nextafterf4, Sleef_nextafterd2, b); + } + + Vectorized erfinv() const { + return mapOrdinary(calc_erfinv); + } + + Vectorized digamma() const { + return mapOrdinary(calc_digamma); + } + + Vectorized igamma(const Vectorized& x) const { + return mapOrdinary(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapOrdinary(calc_igammac, x); + } + + Vectorized i0() const { + return mapOrdinary(calc_i0); + } + + Vectorized i0e() const { + return mapOrdinary(calc_i0e); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized minimum(const Vectorized& other) const { + return {vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)}; + } + + /* Propagates NaN if either 
input is a NaN. */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized minimum(const Vectorized& other) const { + Vectorized tmp = { + vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)}; + tmp = blendv(tmp, *this, isnan()); + return blendv(tmp, other, other.isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized maximum(const Vectorized& other) const { + return {vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)}; + } + + /* Propagates NaN if either input is a NaN. */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized maximum(const Vectorized& other) const { + Vectorized tmp = { + vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)}; + tmp = blendv(tmp, *this, isnan()); + return blendv(tmp, other, other.isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_min(const Vectorized& min) const { + return {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)}; + } + + /* Keeps NaN if actual value is NaN */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_min(const Vectorized& min) const { + Vectorized tmp = {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)}; + return blendv(tmp, *this, isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_max(const Vectorized& max) const { + return {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)}; + } + + /* Keeps NaN if actual value is NaN */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_max(const Vectorized& max) const { + Vectorized tmp = {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)}; + return blendv(tmp, *this, isnan()); + } + + template , int> = 0> + Vectorized swapped() const { + auto swap_mask = GetSwapMaskFloat(); + vtype v0 = vec_perm(_vec0, _vec0, swap_mask); + vtype v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized swapped() const { + vtype v0 = {_vec0[1], _vec0[0]}; + vtype v1 = {_vec1[1], _vec1[0]}; + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + static Vectorized mergee(Vectorized& first, Vectorized& second) { + return { + vec_mergee(first._vec0, second._vec0), + vec_mergee(first._vec1, second._vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + static Vectorized mergeo(Vectorized& first, Vectorized& second) { + return { + vec_mergeo(first._vec0, second._vec0), + vec_mergeo(first._vec1, second._vec1)}; + } + + static Vectorized horizontal_add_perm( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.swapped(); // 2perm + auto second_perm = second.swapped(); // 2perm + // summ + auto first_ret = first + first_perm; // 2add + auto second_ret = second + second_perm; // 2 add + // now lets choose evens + return mergee(first_ret, second_ret); // 2 mergee's + } + + static Vectorized horizontal_sub_perm( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.swapped(); // 2perm + auto second_perm = second.swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets 
choose evens + return mergee(first_ret, second_ret); // 2 mergee's + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized mergee() const { + return {vec_mergee(_vec0, _vec0), vec_mergee(_vec1, _vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized mergeo() const { + return {vec_mergeo(_vec0, _vec0), vec_mergeo(_vec1, _vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized to_vec_float_helper() const { + int32_t values[8] = { + _vec0[0], + _vec0[1], + _vec0[2], + _vec0[3], + _vec0[4], + _vec0[5], + _vec0[6], + _vec0[7], + }; + + return Vectorized{ + values[0], + values[1], + values[2], + values[3], + values[4], + values[5], + values[6], + values[7]}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized to_vec_uint8_helper() const { + // helper function for float to uint8_t conversion + uint8_t values[8] = { + static_cast(_vec0[0]), + static_cast(_vec0[1]), + static_cast(_vec0[2]), + static_cast(_vec0[3]), + static_cast(_vec1[0]), + static_cast(_vec1[1]), + static_cast(_vec1[2]), + static_cast(_vec1[3]), + }; + + return Vectorized{ + values[0], values[1], values[2], values[3], values[4], values[5], + values[6], values[7], 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, + }; + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() + b.vec0(), a.vec1() + b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() - b.vec0(), a.vec1() - b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() * b.vec0(), a.vec1() * b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator/( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() & b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() & b.vecb1())}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() | b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() | b.vecb1())}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() ^ b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() ^ b.vecb1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())} \ + ._not(); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpgt(a.vec0(), b.vec0()), vec_cmpgt(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + 
vec_cmpge(a.vec0(), b.vec0()), vec_cmpge(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmplt(a.vec0(), b.vec0()), vec_cmplt(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmple(a.vec0(), b.vec0()), vec_cmple(a.vec1(), b.vec1())}; \ + } + +ZVECTOR_OPERATORS(float) +ZVECTOR_OPERATORS(double) +ZVECTOR_OPERATORS(int8_t) +ZVECTOR_OPERATORS(uint8_t) +ZVECTOR_OPERATORS(uint16_t) +ZVECTOR_OPERATORS(int16_t) +ZVECTOR_OPERATORS(int32_t) +ZVECTOR_OPERATORS(int64_t) + +#undef ZVECTOR_OPERATORS + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator<<( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr Vectorized::ElementType max_shift = \ + sizeof(Vectorized::ElementType) * CHAR_BIT; \ + \ + Vectorized::ElementType a_array[Vectorized::size()]; \ + Vectorized::ElementType b_array[Vectorized::size()]; \ + Vectorized::ElementType c_array[Vectorized::size()]; \ + \ + a.store(a_array); \ + b.store(b_array); \ + \ + for (int i = 0; i != Vectorized::size(); i++) { \ + typex shift = b_array[i]; \ + if ((static_cast>(shift) < 0) || \ + (shift >= max_shift)) { \ + c_array[i] = 0; \ + } else { \ + c_array[i] = static_cast>(a_array[i]) \ + << shift; \ + } \ + } \ + \ + return Vectorized::loadu(c_array); \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator>>( \ + const Vectorized& a, const Vectorized& b) { \ + /* right shift value to retain sign bit for signed and no bits for \ + * unsigned */ \ + constexpr Vectorized::ElementType max_shift = \ + sizeof(typex) * CHAR_BIT - std::is_signed_v; \ + \ + Vectorized::ElementType a_array[Vectorized::size()]; \ + Vectorized::ElementType b_array[Vectorized::size()]; \ + Vectorized::ElementType c_array[Vectorized::size()]; \ + \ + a.store(a_array); \ + b.store(b_array); \ + \ + for (int i = 0; i != Vectorized::size(); i++) { \ + typex shift = b_array[i]; \ + if ((static_cast>(shift) < 0) || \ + (shift >= max_shift)) { \ + c_array[i] = a_array[i] >> max_shift; \ + } else { \ + c_array[i] = a_array[i] >> shift; \ + } \ + } \ + \ + return Vectorized::loadu(c_array); \ + } \ + \ + template <> \ + inline Vectorized operator~(const Vectorized& a) { \ + return a._not(); \ + } + +ZVECTOR_OPERATORS(int8_t) +ZVECTOR_OPERATORS(uint8_t) +ZVECTOR_OPERATORS(uint16_t) +ZVECTOR_OPERATORS(int16_t) +ZVECTOR_OPERATORS(int32_t) +ZVECTOR_OPERATORS(int64_t) + +#undef ZVECTOR_OPERATORS + +#define DEFINE_MAXMIN_FUNCS(operand_type) \ + template <> \ + Vectorized inline maximum( \ + const Vectorized& a, const Vectorized& b) { \ + return a.maximum(b); \ + } \ + template <> \ + Vectorized inline minimum( \ + const Vectorized& a, const Vectorized& b) { \ + return a.minimum(b); \ + } + +#define DEFINE_CLAMP_MAXMIN_FUNCS(typex) \ + DEFINE_MAXMIN_FUNCS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_min( \ + const Vectorized& a, const Vectorized& min) { \ + return a.clamp_min(min); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_max( \ + const Vectorized& a, const Vectorized& max) { \ + return a.clamp_max(max); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return clamp_max(clamp_min(a, min), max); \ + } + +DEFINE_CLAMP_MAXMIN_FUNCS(int8_t) +DEFINE_CLAMP_MAXMIN_FUNCS(uint8_t) 
+DEFINE_CLAMP_MAXMIN_FUNCS(int16_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int32_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int64_t) +DEFINE_CLAMP_MAXMIN_FUNCS(float) +DEFINE_CLAMP_MAXMIN_FUNCS(double) + +namespace { /* unnamed namespace */ + +#if !defined(vec_float) || __ARCH__ < 13 +#warning \ + "float->int and int->float conversion is simulated. compile for z15 for improved performance" +inline ZSimdVect vec_int_flt(const ZSimdVect x) { + return ZSimdVect{float(x[0]), float(x[1]), float(x[2]), float(x[3])}; +} +inline ZSimdVect vec_flt_int(const ZSimdVect x) { + return ZSimdVect{int(x[0]), int(x[1]), int(x[2]), int(x[3])}; +} +#else +#define vec_int_flt vec_float +#define vec_flt_int vec_signed +#endif + +Vectorized zvec_convert_to_float(const Vectorized& x) { + return {vec_int_flt(x.vec0()), vec_int_flt(x.vec1())}; +} + +Vectorized zvec_convert_to_int(const Vectorized& x) { + return {vec_flt_int(x.vec0()), vec_flt_int(x.vec1())}; +} + +Vectorized zvec_convert_to_float(const Vectorized& x) { + return {vec_double(x.vec0()), vec_double(x.vec1())}; +} + +Vectorized zvec_convert_to_int(const Vectorized& x) { + return {vec_signed(x.vec0()), vec_signed(x.vec1())}; +} + +} /* unnamed namespace */ + +template +Vectorized cast_zvector(const Vectorized& x) { + using cast_type = typename Vectorized::vtype; + return Vectorized{(cast_type)x.vec0(), (cast_type)x.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + __builtin_s390_vfmasb(a.vec0(), b.vec0(), c.vec0()), + __builtin_s390_vfmasb(a.vec1(), b.vec1(), c.vec1())}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + __builtin_s390_vfmadb(a.vec0(), b.vec0(), c.vec0()), + __builtin_s390_vfmadb(a.vec1(), b.vec1(), c.vec1())}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return zvec_convert_to_int(src); +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return zvec_convert_to_int(src); +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + auto input_vec = Vectorized::loadu(src_a); + auto output_vec = zvec_convert_to_float(input_vec); + output_vec.store(dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + auto input_vec = Vectorized::loadu(src_a); + auto output_vec = 
zvec_convert_to_float(input_vec); + output_vec.store(dst_a); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +#define DEFINE_REINTERPRET_CAST_FUNCS(Fst, Cst) \ + template <> \ + C10_ALWAYS_INLINE Vectorized cast( \ + const Vectorized& src) { \ + return cast_zvector(src); \ + } + +#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(Fst) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, double) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, float) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int64_t) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int32_t) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int16_t) + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +#undef DEFINE_REINTERPRET_CAST_FUNCS + +template +struct unpack_type { + using type = T; +}; +template <> +struct unpack_type { + using type = int16_t; +}; +template <> +struct unpack_type { + using type = int16_t; +}; +template <> +struct unpack_type { + using type = int32_t; +}; + +template +struct pack_type { + using type = T; +}; +template <> +struct pack_type { + using type = int8_t; +}; +template <> +struct pack_type { + using type = int16_t; +}; + +namespace { /* unnamed namespace */ + +template ::type> +std::pair, Vectorized> unpack(const Vectorized& x) { + auto vec0 = vec_unpackh(x.vec0()); + auto vec1 = vec_unpackl(x.vec0()); + auto vec2 = vec_unpackh(x.vec1()); + auto vec3 = vec_unpackl(x.vec1()); + return {Vectorized{vec0, vec1}, Vectorized{vec2, vec3}}; +} + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") +template <> +std::pair, Vectorized> unpack( + const Vectorized& x) { + using typeX = typename Vectorized::vtype; + typeX vec0 = vec_unpackh(x.vec0()); + typeX vec1 = vec_unpackl(x.vec0()); + typeX vec2 = vec_unpackh(x.vec1()); + typeX vec3 = vec_unpackl(x.vec1()); + // auto mask = Vectorized(0xFF); + // vec0 = vec0 & mask; + // vec1 = vec1 & mask; + // vec2 = vec2 & mask; + // vec3 = vec3 & mask; + return { + cast_zvector(Vectorized{vec0, vec1}), + cast_zvector(Vectorized{vec2, vec3})}; +} +C10_DIAGNOSTIC_POP() + +template ::type> +Vectorized pack(const Vectorized& first, const Vectorized& second) { + auto vec0 = vec_packs(first.vec0(), first.vec1()); + auto vec1 = vec_packs(second.vec0(), second.vec1()); + return Vectorized{vec0, vec1}; +} + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") +template <> +Vectorized pack( + const Vectorized& first, + const Vectorized& second) { + auto vec0 = vec_packsu(first.vec0(), first.vec1()); + auto vec1 = vec_packsu(second.vec0(), second.vec1()); + return Vectorized{vec0, vec1}; +} +C10_DIAGNOSTIC_POP() + +} /* unnamed namespace */ + +//////////////////////////////////QUANT/////////////////////////////////////////// +template +struct is_vec_specialized_for< + T, + std::enable_if_t()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using value_type = typename T::underlying; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using vinner_type = Vectorized; + using size_type = int; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(value_type); + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + static constexpr int int_num_vecs() { + return float_num_vecs(); + } + using float_vec_return_type = std::array, float_num_vecs()>; + using int_vec_return_type = + std::array, int_num_vecs()>; + + 
private: + vinner_type _vec; + + public: + Vectorized() {} + + explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {} + Vectorized(const T& val) : _vec(val.val_) {} + + C10_ALWAYS_INLINE const vinner_type& vec() const { + return _vec; + } + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return Vectorized{vinner_type::loadu(ptr, count)}; + } + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + _vec.store(ptr, count); + } + + Vectorized relu(Vectorized zero_point) const { + return Vectorized{_vec.maximum(zero_point._vec)}; + } + + Vectorized relu6(Vectorized zero_point, Vectorized q_six) const { + auto ret_max = _vec.maximum(zero_point._vec); + auto ret_min = ret_max.minimum(q_six._vec); + return Vectorized{ret_min}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + int_vec_return_type widening_subtract(Vectorized b) const { + return {*this - b}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + auto float_val = zvec_convert_to_float(_vec); + return {fmadd(scale, float_val, scale_zp_premul)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + auto float_val = zvec_convert_to_float(_vec); + return {(float_val - zero_point) * scale}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized vecf = rhs[0]; + vecf = vecf * Vectorized(inverse_scale); + vecf = vecf.rint() + Vectorized((float)(zero_point)); + auto veci = zvec_convert_to_int(vecf); + + return Vectorized{veci}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 1, int> = 0> + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized vi = inp[0]; + auto vecf = zvec_convert_to_float(vi.vec()); + vecf = vecf * Vectorized(multiplier); + vecf = vecf.rint(); + auto veci = zvec_convert_to_int(vecf) + Vectorized(zero_point); + + return Vectorized{veci}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 4, int> = 0> + int_vec_return_type widening_subtract(Vectorized b) const { + auto ret16 = unpack(_vec); + auto ret16B = unpack(b.vec()); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + auto ret32B_0 = unpack(ret16B.first); + auto ret32B_1 = unpack(ret16B.second); + + return { + Vectorized(ret32_0.first - ret32B_0.first), + Vectorized(ret32_0.second - ret32B_0.second), + Vectorized(ret32_1.first - ret32B_1.first), + Vectorized(ret32_1.second - ret32B_1.second)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + // unpacking unsigned as signed + auto ret16 = unpack(_vec); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + + auto vecf_0 = zvec_convert_to_float(ret32_0.first); + auto vecf_1 = zvec_convert_to_float(ret32_0.second); + + auto vecf_2 = zvec_convert_to_float(ret32_1.first); + auto vecf_3 = 
zvec_convert_to_float(ret32_1.second); + return { + fmadd(scale, vecf_0, scale_zp_premul), + fmadd(scale, vecf_1, scale_zp_premul), + fmadd(scale, vecf_2, scale_zp_premul), + fmadd(scale, vecf_3, scale_zp_premul)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + // unpacking unsigned as signed + auto ret16 = unpack(_vec); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + + auto vecf_0 = zvec_convert_to_float(ret32_0.first); + auto vecf_1 = zvec_convert_to_float(ret32_0.second); + + auto vecf_2 = zvec_convert_to_float(ret32_1.first); + auto vecf_3 = zvec_convert_to_float(ret32_1.second); + + return { + (vecf_0 - zero_point) * scale, + (vecf_1 - zero_point) * scale, + (vecf_2 - zero_point) * scale, + (vecf_3 - zero_point) * scale}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto vec_inverse = Vectorized(inverse_scale); + auto vec_zero_point = Vectorized((float)zero_point); + + auto vecf0 = rhs[0]; + auto vecf2 = rhs[1]; + auto vecf4 = rhs[2]; + auto vecf6 = rhs[3]; + + vecf0 = vecf0 * vec_inverse; + vecf2 = vecf2 * vec_inverse; + vecf4 = vecf4 * vec_inverse; + vecf6 = vecf6 * vec_inverse; + + vecf0 = vecf0.rint() + vec_zero_point; + vecf2 = vecf2.rint() + vec_zero_point; + vecf4 = vecf4.rint() + vec_zero_point; + vecf6 = vecf6.rint() + vec_zero_point; + + auto veci0 = zvec_convert_to_int(vecf0); + auto veci2 = zvec_convert_to_int(vecf2); + auto veci4 = zvec_convert_to_int(vecf4); + auto veci6 = zvec_convert_to_int(vecf6); + + auto vecshi0 = pack(veci0, veci2); + auto vecshi2 = pack(veci4, veci6); + auto ret = pack(vecshi0, vecshi2); + + return Vectorized{ret}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 4, int> = 0> + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized vec_multiplier = Vectorized(multiplier); + Vectorized vec_zero_point = Vectorized(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + auto vecf0 = zvec_convert_to_float(vi0.vec()); + auto vecf2 = zvec_convert_to_float(vi1.vec()); + + auto vecf4 = zvec_convert_to_float(vi2.vec()); + auto vecf6 = zvec_convert_to_float(vi3.vec()); + + vecf0 = vecf0 * vec_multiplier; + vecf2 = vecf2 * vec_multiplier; + + vecf4 = vecf4 * vec_multiplier; + vecf6 = vecf6 * vec_multiplier; + + vecf0 = vecf0.rint(); + vecf2 = vecf2.rint(); + vecf4 = vecf4.rint(); + vecf6 = vecf6.rint(); + + auto veci0 = zvec_convert_to_int(vecf0); + auto veci2 = zvec_convert_to_int(vecf2); + auto veci4 = zvec_convert_to_int(vecf4); + auto veci6 = zvec_convert_to_int(vecf6); + + veci0 = veci0 + vec_zero_point; + veci2 = veci2 + vec_zero_point; + + veci4 = veci4 + vec_zero_point; + veci6 = veci6 + vec_zero_point; + + auto vecshi0 = pack(veci0, veci2); + auto vecshi2 = pack(veci4, veci6); + + auto ret = pack(vecshi0, vecshi2); + + return Vectorized{ret}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + return Vectorized{_vec.eq(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + return Vectorized{_vec.ne(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const { + return 
Vectorized{_vec.gt(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const { + return Vectorized{_vec.ge(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const { + return Vectorized{_vec.lt(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const { + return Vectorized{_vec.le(other._vec)}; + } + + Vectorized clamp_min(const Vectorized& min) const { + return Vectorized{_vec.clamp_min(min._vec)}; + } + + Vectorized clamp_max(const Vectorized& max) const { + return Vectorized{_vec.clamp_max(max._vec)}; + } + + Vectorized minimum(const Vectorized& other) const { + return Vectorized{_vec.minimum(other._vec)}; + } + + Vectorized maximum(const Vectorized& other) const { + return Vectorized{_vec.maximum(other._vec)}; + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() + b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() - b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() * b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator/( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() / b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() & b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() | b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() ^ b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() == b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() != b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() > b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() >= b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() < b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() <= b.vec()}; \ + } + +ZVECTOR_OPERATORS(c10::qint32) +ZVECTOR_OPERATORS(c10::qint8) +ZVECTOR_OPERATORS(c10::quint8) + +#undef ZVECTOR_OPERATORS + +DEFINE_CLAMP_MAXMIN_FUNCS(c10::quint8) +DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint8) +DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint32) + +template +constexpr auto real_mask() { + return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFF, 0, 0xFFFFFFFF, 0}; +} + +template <> +constexpr auto real_mask() { + return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFFFFFFFFFF, 0}; +} + +template +constexpr auto image_mask() { + return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFF, 0, 0xFFFFFFFF}; +} + +template <> +constexpr auto image_mask() { + return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFFFFFFFFFF}; +} + +template +constexpr auto 
rsign_mask() { + return ZSimdVect{-0.f, 0.f, -0.f, 0.f}; +} + +template <> +constexpr auto rsign_mask() { + return ZSimdVect{-0.0, 0.f}; +} + +template +constexpr auto isign_mask() { + return ZSimdVect{0.0, -0.f, 0.0, -0.f}; +} + +template <> +constexpr auto isign_mask() { + return ZSimdVect{0.0, -0.0}; +} + +template +constexpr auto image_one() { + return ZSimdVect{0, 1.f, 0, 1.f}; +} + +template <> +constexpr auto image_one() { + return ZSimdVect{0.0, 1.0}; +} + +template +constexpr auto pi_half() { + return ZSimdVect{(float)(M_PI / 2.0), 0.f, (float)(M_PI / 2.0), 0.f}; +} + +template <> +constexpr auto pi_half() { + return ZSimdVect{M_PI / 2.0, 0.0}; +} + +template +constexpr auto image_half() { + return ZSimdVect{0, 0.5f, 0, 0.5f}; +} + +template <> +constexpr auto image_half() { + return ZSimdVect{0.0, 0.5}; +} + +template +constexpr U log2e_inv() { + return static_cast(1.4426950408889634); +} + +template +constexpr U log10e_inv() { + return static_cast(0.43429448190325176); +} + +template +struct is_vec_specialized_for< + T, + std::enable_if_t()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using underline_type = decltype(std::declval().imag()); + using value_type = T; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using vinner_type = Vectorized; + using size_type = int; + using vinner_data = typename Vectorized::vinner_data; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(value_type); + } + + private: + vinner_type _vec; + + public: + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(const vinner_data& v) + : _vec{v.first, v.second} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2) + : _vec{s1.real(), s1.imag(), s2.real(), s2.imag()} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4) + : _vec{ + s1.real(), + s1.imag(), + s2.real(), + s2.imag(), + s3.real(), + s3.imag(), + s4.real(), + s4.imag()} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s) {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s, s, s) {} + + C10_ALWAYS_INLINE operator vinner_type() const { + return _vec; + } + + C10_ALWAYS_INLINE const vinner_type& vec() const { + return _vec; + } + + C10_ALWAYS_INLINE operator vinner_data() const { + return _vec.data(); + } + + C10_ALWAYS_INLINE vinner_data data() const { + return _vec.data(); + } + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return Vectorized{vinner_type::loadu(ptr, 2 * count)}; + } + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + return _vec.store(ptr, 2 * count); + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + vinner_type vmask = mask.vec(); + auto mask_complex = vinner_type( + vec_mergeh(vmask.vec0(), vmask.vec0()), + vec_mergeh(vmask.vec1(), vmask.vec1())); + return Vectorized{vinner_type::blendv(a.vec(), b.vec(), mask_complex)}; + } + + template + static auto C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int mask_complex = maskForComplex(mask); + return Vectorized{ + vinner_type::template blend(a.vec(), b.vec())}; + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized(base, base + step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return 
Vectorized( + base, + base + step, + base + value_type(2) * step, + base + value_type(3) * step); + } + + template + static inline std::enable_if_t<(Z >= C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + return b; + } + + template + static inline std::enable_if_t<(Z < C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + if (count == Z) + return blend(a, b); + else + return set_inner(a, b, count); + } + + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + if (count == 0) + return a; + return set_inner<1, size()>(a, b, count); + } + + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(const T&)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{ + f(T(v0[0], v0[1])), + f(T(v0[2], v0[3])), + f(T(v1[0], v1[1])), + f(T(v1[2], v1[3]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(const T&)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(T)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{ + f(T(v0[0], v0[1])), + f(T(v0[2], v0[3])), + f(T(v1[0], v1[1])), + f(T(v1[2], v1[3]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(T)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + inline Vectorized mapOrdinary( + T (*const f)(const T&, const T&), + const Vectorized& b) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + auto bvec = b.vec(); + auto b0 = bvec.vec0(); + auto b1 = bvec.vec1(); + T a00 = f(T(v0[0], v0[1]), T(b0[0], b0[1])); + T a01 = f(T(v0[2], v0[3]), T(b0[2], b0[3])); + T a02 = f(T(v1[0], v1[1]), T(b1[0], b1[1])); + T a03 = f(T(v1[2], v1[3]), T(b1[2], b1[3])); + return Vectorized{a00, a01, a02, a03}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + inline Vectorized mapOrdinary( + T (*const f)(const T&, const T&), + const Vectorized& b) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + auto bvec = b.vec(); + auto b0 = bvec.vec0(); + auto b1 = bvec.vec1(); + U a00 = f(U(v0[0], v0[1]), U(b0[0], b0[1])); + U a01 = f(U(v1[0], v1[1]), U(b1[0], b1[1])); + return Vectorized{a00, a01}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg( + const typename Vectorized::vinner_type& a) { + const auto swap_mask = ZSimdVectBinary{ + 0, 1, 2, 3, 20, 21, 22, 23, 8, 9, 10, 11, 28, 29, 30, 31}; + + auto a_neg = a.neg(); + vtype v0 = vec_perm(a_neg.vec0(), a.vec0(), swap_mask); + vtype v1 = vec_perm(a_neg.vec1(), a.vec1(), swap_mask); + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg( + const typename Vectorized::vinner_type& a) { + auto a_neg = a.neg(); + vtype v0 = {a_neg.vec0()[0], a.vec0()[1]}; + vtype v1 = {a_neg.vec1()[0], a.vec1()[1]}; + return {v0, v1}; + } + + Vectorized angle2_() const { + auto b_a = 
_vec.swapped(); // b a + return Vectorized{_vec.atan2(b_a).swapped()}; + } + + Vectorized angle() const { + return angle2_().real(); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized{vinner_type(image_one())}; + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * + Vectorized{vinner_type(image_half())}; // i/2*ln() + } + + Vectorized atanh() const { + return mapOrdinary(std::atanh); + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) +#if 1 + vinner_type cnj = conj().vec(); + vinner_type b_a = cnj.swapped(); + vinner_type ab = cnj * b_a; + vinner_type im = ab + ab; + vinner_type val_2 = _vec * _vec; + vinner_type val_2_swapped = val_2.swapped(); + vinner_type re = vinner_type::horizontal_sub_perm(val_2, val_2_swapped); + re = vinner_type(static_cast(1)) - re; + constexpr int blend_mask = + blend_choice(); // 0x0A for complex , 0xAA for complex + vinner_type blendx = vinner_type::template blend(re, im); + auto root = Vectorized(blendx).sqrt(); + auto ln = Vectorized(Vectorized(b_a) + root).log(); + return Vectorized(ln.vec().swapped()).conj(); +#else + return mapOrdinary(std::asin); +#endif + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(vinner_type(pi_half())) - asin(); + } + + Vectorized sin() const { + return mapOrdinary(std::sin); + } + Vectorized sinh() const { + return mapOrdinary(std::sinh); + } + Vectorized cos() const { + return mapOrdinary(std::cos); + } + Vectorized cosh() const { + return mapOrdinary(std::cosh); + } + Vectorized ceil() const { + return Vectorized{_vec.ceil()}; + } + Vectorized floor() const { + return Vectorized{_vec.floor()}; + } + Vectorized neg() const { + return Vectorized(_vec.neg()); + } + Vectorized round() const { + return Vectorized{_vec.round()}; + } + Vectorized tan() const { + return mapOrdinary(std::tan); + } + Vectorized tanh() const { + return mapOrdinary(std::tanh); + } + Vectorized trunc() const { + return Vectorized{_vec.trunc()}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + auto eq = _vec.eq(other._vec); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + auto real = eq & vinner_type(real_mask()); + auto imag = (eq & vinner_type(image_mask())).swapped(); + return Vectorized{real & imag}; + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + auto ne = _vec.ne(other._vec); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + auto real = ne & vinner_type(real_mask()); + auto imag = (ne & vinner_type(image_mask())).swapped(); + return Vectorized{real | imag}; + } + + Vectorized real() const { + return Vectorized(_vec & vinner_type(real_mask())); + } + Vectorized imag_() const { + return Vectorized(_vec & vinner_type(image_mask())); + } + Vectorized imag() const { + return Vectorized{ + (_vec & vinner_type(image_mask())).swapped()}; + } + + Vectorized conj() const { + return Vectorized(_vec ^ vinner_type(isign_mask())); + } + + vinner_data abs_2_() const { + auto a = _vec * _vec; + a = a + a.swapped(); + return a.mergee().data(); + } + + static T abs_helper(const T& value) { + return T(std::abs(value)); + } + + Vectorized abs() const { + 
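+    // No vectorized fast path for complex abs here: abs_helper wraps
+    // std::abs for a single complex value and mapOrdinary applies it to
+    // each element in turn.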
return mapOrdinary(abs_helper); + } + + Vectorized exp() const { + return mapOrdinary(std::exp); + } + + Vectorized exp2() const { + return mapOrdinary(exp2_impl); + } + + Vectorized expm1() const { + return mapOrdinary(std::expm1); + } + + Vectorized log() const { + return mapOrdinary(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return Vectorized{ret._vec * vinner_type(log2e_inv())}; + } + + Vectorized log10() const { + auto ret = log(); + return Vectorized{ret._vec * vinner_type(log10e_inv())}; + } + + Vectorized log1p() const { + return mapOrdinary(std::log1p); + } + + Vectorized sgn() const { + return mapOrdinary(at::native::sgn_impl); + } + + Vectorized pow(const Vectorized& exp) const { + return mapOrdinary(std::pow, exp); + } + + Vectorized sqrt() const { + return mapOrdinary(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + vinner_type c_d = _vec ^ vinner_type(isign_mask()); + vinner_type abs = abs_2_(); + return Vectorized{c_d / abs}; + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized lt(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized le(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized gt(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized ge(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() + b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() - b.vec()}; \ + } \ + \ + template <> \ + Vectorized inline operator*( \ + const Vectorized& a, const Vectorized& b) { \ + /* (a + bi) * (c + di) = (ac - bd) + (ad + bc)i */ \ + Vectorized::vinner_type bv = b.vec(); \ + \ + /* this is more z arch friendly than simulating horizontal from x86 */ \ + Vectorized::vinner_type vi = bv.mergeo(); \ + Vectorized::vinner_type vr = bv.mergee(); \ + vi = vi ^ \ + Vectorized::vinner_type( \ + rsign_mask::underline_type>()); \ + Vectorized::vinner_type ret = a.vec() * vr; \ + Vectorized::vinner_type vx_swapped = a.vec().swapped(); \ + ret = fmadd(vx_swapped, vi, ret); \ + \ + return Vectorized{ret}; \ + } \ + \ + template <> \ + Vectorized inline operator/( \ + const Vectorized& a, const Vectorized& b) { \ + /* Unfortunately, this breaks some tests */ \ + /* Implement it like it's done for avx2 */ \ + auto fabs_cd = b.vec().abs(); /* |c| |d| */ \ + auto fabs_dc = fabs_cd.swapped(); /* |d| |c| */ \ + auto scale = Vectorized::vinner_type{1.0} / \ + maximum(fabs_cd, fabs_dc); /* 1/sc 1/sc */ \ + auto a2 = a.vec() * scale; /* a/sc b/sc */ \ + auto b2 = b.vec() * scale; /* c/sc d/sc */ \ + auto acbd2 = a2 * b2; /* ac/sc^2 bd/sc^2 */ \ + \ + auto dc2 = b2.swapped(); /* d/sc c/sc */ \ + dc2 = Vectorized::real_neg(dc2); /* -d/|c,d| c/sc */ \ + auto adbc2 = a2 * dc2; /* -ad/sc^2 bc/sc^2 */ \ + auto sum1 = acbd2 + acbd2.swapped(); /* (ac+bd)/sc^2 (ac+bd)/sc^2 */ \ + auto sum2 = adbc2 + adbc2.swapped(); /* (bc-ad)/sc^2 (bc-ad)/sc^2 */ \ + auto res2 = Vectorized::vinner_type::mergee( \ + sum1, sum2); /* (ac+bd)/sc^2 
(bc-ad)/sc^2 */ \ + \ + /* get the denominator */ \ + Vectorized::vinner_type denom2 = \ + Vectorized{b2}.abs_2_(); /* (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 */ \ + res2 = res2 / denom2; \ + return Vectorized{res2}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() & b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() | b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() ^ b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() == b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() != b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } + +ZVECTOR_OPERATORS(c10::complex) +ZVECTOR_OPERATORS(c10::complex) + +#undef ZVECTOR_OPERATORS + +template = 0> +std::pair, Vectorized> inline inner_interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + using vtype = typename Vectorized::vtype; + vtype ab00 = {a.vec0()[0], b.vec0()[0]}; + vtype ab11 = {a.vec0()[1], b.vec0()[1]}; + vtype ab2_00 = {a.vec1()[0], b.vec1()[0]}; + vtype ab2_11 = {a.vec1()[1], b.vec1()[1]}; + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11}); +} + +template = 0> +std::pair, Vectorized> inline inner_deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + using vtype = typename Vectorized::vtype; + vtype aa01 = {a.vec0()[0], a.vec1()[0]}; + vtype aa23 = {b.vec0()[0], b.vec1()[0]}; + + vtype bb_01 = {a.vec0()[1], a.vec1()[1]}; + vtype bb_23 = {b.vec0()[1], b.vec1()[1]}; + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair(Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23}); +} + +template = 0> +std::pair, Vectorized> inline inner_interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + using vtype = typename Vectorized::vtype; + vtype ab0011 = vec_mergeh(a.vec0(), b.vec0()); + vtype ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vtype ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vtype ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233}); +} + +template = 0> +std::pair, Vectorized> inline 
inner_deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + using vtype = typename Vectorized::vtype; + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vtype a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); + vtype a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vtype aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vtype bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vtype a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vtype a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vtype aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vtype bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2}); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2< + int32_t>(const Vectorized& a, const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2< + int64_t>(const Vectorized& a, const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template +std::enable_if_t< + std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(const Vectorized& src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 64 bits + auto vec_int = src.to_vec_float_helper(); + + return zvec_convert_to_float(vec_int); +} + +template +std::enable_if_t< + std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(const Vectorized& src) { + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + auto vec_int = clamp( + zvec_convert_to_int(src), + Vectorized(min_val), + Vectorized(max_val)); + + return vec_int.to_vec_uint8_helper(); +} + +#undef DEFINE_CLAMP_MAXMIN_FUNCS +#undef DEFINE_MAXMIN_FUNCS +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h new file mode 100644 index 0000000000000000000000000000000000000000..75511251f86156b166c70f4d48ad81dde6207054 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h @@ -0,0 +1,409 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include + +// clang-format off +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// clang-format on + +#include +#include +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; +} + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << "]"; + return stream; +} + +#if defined(CPU_CAPABILITY_AVX512) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castpd_ps(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castps_pd(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castsi512_ps(src); +} + +template <> +inline Vectorized cast( + const Vectorized& src) { + return _mm512_castsi512_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm512_i64gather_pd(vindex, base_addr, scale); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm512_i32gather_ps(vindex, base_addr, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. 
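+// mask_gather: lanes whose mask value is all ones are gathered from
+// base_addr (indexed by vindex, scaled by `scale` bytes); all other lanes
+// keep the corresponding value from src. The floating-point mask is first
+// turned into a k-mask by comparing against all-ones, then fed to the
+// masked gather intrinsic.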
+template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF)); + auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF)); + auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm512_cvtpd_epi64(src); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm512_cvttps_epi32(src); +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm512_cvtepi64_pd(src); +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm512_cvtepi32_ps(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4); + return std::make_pair( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, + // a15} b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, + // b14, b15} + // + // return: + // {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, + // b15} + __m512i idx1 = + _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + return std::make_pair( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + // The members of indices have been written in binary format for better + // understandability + __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> 
inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, + // a15, b15} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, + // a15} + // {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, + // b15} + __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = + _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_ps(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_pd(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_epi64(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = + _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_epi32(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi16( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31); + return _mm512_permutexvar_epi16(mask, v); +} + +inline __m512i flip8(const __m512i& v) { + const __m512i mask1 = _mm512_set_epi8( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15); + const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6); + auto reversed_vec = _mm512_shuffle_epi8(v, mask1); + return _mm512_permutexvar_epi64(mask2, reversed_vec); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m512i* self_ = reinterpret_cast(self.as_bytes()); + const __m512i* other_ = reinterpret_cast(other.as_bytes()); + __m512i out = _mm512_and_si512(*self_, *other_); + Vectorized ret; + // We do not have a constructer that takes __m512i, so we need to memcpy + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + +#endif // defined(CPU_CAPABILITY_AVX512) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..e23ac2fa7bb4531cdc873ef42ea4de12c4aba41e --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -0,0 +1,1937 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +#ifndef SLEEF_CONST +#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER) +#define SLEEF_CONST const +#else +#define SLEEF_CONST +#endif +#define SLEEF_CONST_OLD SLEEF_CONST +#else +#define SLEEF_CONST_OLD +#endif + +// bfloat16 conversion +static inline void cvtbf16_fp32(const __m256i& a, __m512& o) { + o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m256i cvtfp32_bf16(const __m512& src) { + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones); + auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm512_add_epi32(t_lo, vec_bias); + t_hi = _mm512_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm512_add_epi32(t_lo, lo); + t_hi = _mm512_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm512_srli_epi32(t_lo, 16); + t_hi = _mm512_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo); + t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi); + + t_lo = _mm512_packus_epi32( + t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, t_lo); +} + +static inline __m512i merge_compare_result(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + lo = _mm512_srli_epi32(lo, 16); + hi = _mm512_srli_epi32(hi, 16); + auto out = _mm512_packus_epi32(lo, hi); + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, out); +} + +// float16 conversion 
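+// Unlike bfloat16 above, float16 has native conversion instructions, so
+// these helpers simply wrap _mm512_cvtph_ps / _mm512_cvtps_ph with
+// round-to-nearest instead of emulating the rounding with integer shifts.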
+static inline void cvtfp16_fp32(const __m256i& a, __m512& o) { + o = _mm512_cvtph_ps(a); +} + +static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtfp16_fp32(lo, o1); + cvtfp16_fp32(hi, o2); +} + +static inline __m256i cvtfp32_fp16(const __m512& src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) { + __m256i lo = + _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i hi = + _mm512_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo)); + __m256 t_hi = _mm256_castsi256_ps(hi); + return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1)); +} + +// dtype conversion between float16/bfloat16 and float32 +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m256i& a, __m512& o); +template <> +inline void cvt_to_fp32(const __m256i& a, __m512& o) { + cvtbf16_fp32(a, o); +} +template <> +inline void cvt_to_fp32(const __m256i& a, __m512& o) { + cvtfp16_fp32(a, o); +} + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2); +template <> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) { + cvtbf16_fp32(a, o1, o2); +} +template <> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) { + cvtfp16_fp32(a, o1, o2); +} + +template < + typename T, + bool is_compare_op = false, + typename std::enable_if_t, int> = 0> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b); +template <> +inline __m512i cvt_from_fp32( + const __m512& a, + const __m512& b) { + return cvtfp32_bf16(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return merge_compare_result(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return cvtfp32_fp16(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return cvtfp32_fp16(a, b); +} + +template +class Vectorized16 { + static_assert( + is_reduced_floating_point_v, + "Support only float16 and bfloat16."); + + private: + __m512i values; + + public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 32; + } + Vectorized16() {} + Vectorized16(__m512i v) : values(v) {} + Vectorized16(T val) { + value_type uw = val.x; + values = _mm512_set1_epi16(uw); + } + Vectorized16( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32) { + values = _mm512_set_epi16( + val32.x, + val31.x, + val30.x, + val29.x, + val28.x, + val27.x, + val26.x, + val25.x, + val24.x, + val23.x, + val22.x, + val21.x, + val20.x, + val19.x, + val18.x, + val17.x, + val16.x, + val15.x, + val14.x, + val13.x, + val12.x, + val11.x, + val10.x, + val9.x, + val8.x, + val7.x, + val6.x, + val5.x, + val4.x, + val3.x, + val2.x, + val1.x); + } + operator __m512i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + int zero_mask() const { + // 
returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0)); + } + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) + return _mm512_loadu_si512(reinterpret_cast(ptr)); + + __mmask32 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi16(mask, ptr); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi16(ptr, mask, values); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange( + T base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + case 16: + return blend<65535>(a, b); + case 17: + return blend<131071>(a, b); + case 18: + return blend<262143>(a, b); + case 19: + return blend<524287>(a, b); + case 20: + return blend<1048575>(a, b); + case 21: + return blend<2097151>(a, b); + case 22: + return blend<4194303>(a, b); + case 23: + return blend<8388607>(a, b); + case 24: + return blend<16777215>(a, b); + case 25: + return blend<33554431>(a, b); + case 26: + return blend<67108863>(a, b); + case 27: + return blend<134217727>(a, b); + case 28: + return blend<268435455>(a, b); + case 29: + return blend<536870911>(a, b); + case 30: + return blend<1073741823>(a, b); + case 31: + return blend<2147483647>(a, b); + } + return b; + } +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wignored-qualifiers" + + Vectorized map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized 
isnan() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __mmask16 lo_mask, hi_mask; + __m512 zero = _mm512_set1_ps(0.0); + __m512i zeroi = _mm512_castps_si512(zero); + lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q); + lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF)); + hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q); + hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF)); + return merge_compare_result(lo, hi); + } +#pragma clang diagnostic pop + Vectorized abs() const { + return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values); + } + Vectorized angle() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto angle_lambda = [](__m512 values) { + const auto zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto non_nan_mask_vec = _mm512_mask_set1_epi32( + _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask( + _mm512_castsi512_ps(non_nan_mask_vec), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf16_u10); + } + Vectorized acosh() const { + return map(Sleef_acoshf16_u10); + } + Vectorized asin() const { + return map(Sleef_asinf16_u10); + } + Vectorized asinh() const { + return map(Sleef_asinhf16_u10); + } + Vectorized atan() const { + return map(Sleef_atanf16_u10); + } + Vectorized atanh() const { + return map(Sleef_atanhf16_u10); + } + Vectorized atan2(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f16_u10(lo, b1); + auto o2 = Sleef_atan2f16_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized copysign(const Vectorized& sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m512i mask_value = _mm512_set1_epi32(~0x80008000); + __m512i mask_signbit = _mm512_set1_epi32(0x80008000); + return Vectorized(_mm512_or_si512( + _mm512_and_si512(values, mask_value), + _mm512_and_si512(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff16_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf16_u15); + } + Vectorized erfinv() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf16_u10); + } + Vectorized exp2() const { + return map(Sleef_exp2f16_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f16_u10); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + __m512 x_lo, 
x_hi; + cvt_to_fp32(values, x_lo, x_hi); + __m512 q_lo, q_hi; + cvtbf16_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf16(x_lo, q_lo); + auto o2 = Sleef_fmodf16(x_hi, q_hi); + return cvt_from_fp32(o1, o2); + } + Vectorized hypot(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf16_u05(lo, b1); + auto o2 = Sleef_hypotf16_u05(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0e() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized digamma() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_digamma(tmp1[i]); + tmp2[i] = calc_digamma(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized igamma(const Vectorized& x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + + Vectorized igammac(const Vectorized& x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf16_u10); + } + Vectorized log2() const { + return map(Sleef_log2f16_u10); + } + Vectorized log10() const { + return map(Sleef_log10f16_u10); + } + Vectorized log1p() const { + 
return map(Sleef_log1pf16_u10); + } + Vectorized sin() const { + return map(Sleef_sinf16_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf16_u10); + } + Vectorized cos() const { + return map(Sleef_cosf16_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf16_u10); + } + Vectorized ceil() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_ceil_ps(lo); + auto o2 = _mm512_ceil_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized floor() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_floor_ps(lo); + auto o2 = _mm512_floor_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized neg() const { + return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000)); + } + Vectorized round() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_roundscale_ps( + lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = _mm512_roundscale_ps( + hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf16_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf16_u10); + } + Vectorized trunc() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = + _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = + _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf16_u10); + } + Vectorized sqrt() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_sqrt_ps(lo); + auto o2 = _mm512_sqrt_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized reciprocal() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, lo); + auto o2 = _mm512_div_ps(ones, hi); + return cvt_from_fp32(o1, o2); + } + Vectorized rsqrt() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo)); + auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi)); + return cvt_from_fp32(o1, o2); + } + Vectorized pow(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_powf16_u10(lo, b1); + auto o2 = Sleef_powf16_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + + private: + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvt_to_fp32(values, a_lo, a_hi); + cvt_to_fp32(b.values, b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); + } + + public: + Vectorized inline operator>(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator<(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator>=(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + 
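+      // Expand the compare k-mask into a full-width vector: lanes where the
+      // comparison held become all ones, the remaining lanes stay zero.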
return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator<=(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator==(const Vectorized16& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator!=(const Vectorized16& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } +}; + +template +static inline Vectorized binary_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvt_to_fp32(__m512i(a), a_lo, a_hi); + cvt_to_fp32(__m512i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = BFloat16; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & 
Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
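// Scalar reference for the NaN-propagating minimum/maximum computed by these
// specializations (a sketch under the same IEEE 754-201x semantics; <cmath>
// and <limits> assumed):
//   float minimum_ref(float a, float b) {
//     if (std::isnan(a) || std::isnan(b))
//       return std::numeric_limits<float>::quiet_NaN();
//     return std::min(a, b);
//   }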
+ auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, BFloat16* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = _mm512_loadu_ps(&src[i]); + __m512 b = _mm512_loadu_ps(&src[i + 16]); + + __m512i bf = cvtfp32_bf16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, BFloat16* dst, int64_t n) { + auto load_float = [](const double* src) -> __m512 { + // Load one float vector from an array of doubles + __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src)); + __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8)); + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = load_float(&src[i]); + __m512 b = load_float(&src[i + 16]); + + __m512i bf = cvtfp32_bf16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + cvtbf16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) { + __m512i r[8]; + // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15 e0e1 e2e3 e4e5 e6e7 e8e9 + // e10e11 e12e13 e14e15 b0-b15 f0-f15 c0-c15 g0-g15 d0-d15 h0-h15 i0-i15 + // m0-m15 j0-j15 
n0-n15 k0-k15 o0-o15 l0-l15 p0-p15 +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; i++) { + r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01); + r[i + 4] = + _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01); + } + + // u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 + // f8f9 e10e11 f10f11 u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 + // f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15 u2: c0c1 d0d1 c2c3 d2d3 c8c9 + // d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11 u3: c4c5 + // d4b5 c6c7 d6b7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7 g12g13 + // h12h13 g14g15 h14h15 i j m n k l o p +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 8; i += 2) { + u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]); + u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]); + } + + // r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 + // g8g9 h8h9 r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 + // g2g3 h2h3 e10e11 f10f11 g10g11 h10h11 r2: a4a5 b4b5 c4c5 d4b5 a12a13 b12b13 + // c12c13 d12d13 r3: a6a7 b6b7 c6c7 d6b7 a14a15 b14b15 c14c15 d14d15 r4: i j k + // l m n o p + r[0] = _mm512_unpacklo_epi64(u[0], u[2]); + r[1] = _mm512_unpackhi_epi64(u[0], u[2]); + r[2] = _mm512_unpacklo_epi64(u[1], u[3]); + r[3] = _mm512_unpackhi_epi64(u[1], u[3]); + r[4] = _mm512_unpacklo_epi64(u[4], u[6]); + r[5] = _mm512_unpackhi_epi64(u[4], u[6]); + r[6] = _mm512_unpacklo_epi64(u[5], u[7]); + r[7] = _mm512_unpackhi_epi64(u[5], u[7]); + + __m512i const1 = _mm512_set_epi32( + 0x00370035, + 0x00330031, + 0x00270025, + 0x00230021, + 0x00170015, + 0x00130011, + 0x00070005, + 0x00030001, + 0x00360034, + 0x00320030, + 0x00260024, + 0x00220020, + 0x00160014, + 0x00120010, + 0x00060004, + 0x00020000); + __m512i const2 = _mm512_set_epi32( + 0x003f003d, + 0x003b0039, + 0x002f002d, + 0x002b0029, + 0x001f001d, + 0x001b0019, + 0x000f000d, + 0x000b0009, + 0x003e003c, + 0x003a0038, + 0x002e002c, + 0x002a0028, + 0x001e001c, + 0x001a0018, + 0x000e000c, + 0x000a0008); + // merge values from two regs + // 0-- 1-- + // 8-- 9-- + // 2-- 3-- + // 10-- 11-- + // 4-- 5-- + // 12-- 13-- + // 6-- 7-- + // 14-- 15-- +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; i++) { + u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]); + u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]); + } +} + +// TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16 +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607 +template <> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst) { + __m256i t[16]; + // load from src to registers + // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 + // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 + // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 + // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 + // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15 + // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 + // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15 + // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15 + // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15 + // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15 + // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 + // l: l0 
l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15 + // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15 + // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 + // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 + // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; i++) { + t[i] = + _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); + } + + __m512i u[8]; + _transpose_mxn_half_16_16(t, u); + +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; i++) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x0)); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x01)); + } +} + +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607 +template <> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst) { + __m256i t[16]; + // load from src to registers + // Same matrix indices as above transpose_mxn +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; i++) { + t[i] = + _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); + } + + __m512i u[8]; + _transpose_mxn_half_16_16(t, u); + +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; i++) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x0)); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x01)); + } +} + +static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { + // t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59 + // t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 63 + // t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123 + // t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127 + // t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171 144 + // ... 187 t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 + // 175 148 ... 191 t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 + // 234 203 235 208 ... 251 t[7]: 196 228 197 229 198 230 199 231 204 236 205 + // 237 206 238 207 239 212 ... 255 t[8]: 256 288 257 289 258 290 259 291 264 + // 296 265 297 266 298 267 299 272 ... 315 t[9]: 260 292 261 293 262 294 263 + // 295 268 300 269 301 270 302 271 303 276 ... 319 t[10]: 320 352 321 353 322 + // 354 323 355 328 360 329 361 330 362 331 363 336 ... 379 t[11]: 324 356 325 + // 357 326 358 327 359 332 364 333 365 334 366 335 367 340 ... 383 t[12]: 384 + // 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427 400 ... 443 + // t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431 404 + // ... 447 t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 + // 491 464 ... 507 t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 + // 494 463 495 468 ... 511 t[16]: 512 544 513 545 514 546 515 547 520 552 521 + // 553 522 554 523 555 528 ... 571 + // ... + // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 + // 980 ... 
1023 +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; ++i) { + d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]); + d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]); + } + + // t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121 + // t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123 + // t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125 + // t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127 + // t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233 144 + // ... 249 t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 + // 235 146 ... 251 t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 + // 173 205 237 148 ... 253 t[7]: 134 166 198 230 135 167 199 231 142 174 206 + // 238 143 175 207 239 150 ... 255 t[8]: 256 288 320 352 257 289 321 353 264 + // 296 328 360 265 297 329 361 272 ... 377 t[9]: 258 290 322 354 259 291 323 + // 355 266 298 330 362 267 299 331 363 274 ... 379 t[10]: 260 292 324 356 261 + // 293 325 357 268 300 332 364 269 301 333 365 276 ... 381 t[11]: 262 294 326 + // 358 263 295 327 359 270 302 334 366 271 303 335 367 278 ... 383 t[12]: 384 + // 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489 400 ... 505 + // t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491 402 + // ... 507 t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 + // 493 404 ... 509 t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 + // 431 463 495 406 ... 511 t[16]: 512 544 576 608 513 545 577 609 520 552 584 + // 616 521 553 585 617 528 ... 633 + // ... + // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 + // 918 ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; ++i) { + r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]); + r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]); + r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]); + r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]); + } + + // t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248 + // t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249 + // t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250 + // t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251 + // t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252 + // t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 253 + // t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254 + // t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255 + // t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488 272 + // ... 504 t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 + // 489 273 ... 505 t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 + // 426 458 490 274 ... 506 t[11]: 259 291 323 355 387 419 451 483 267 299 331 + // 363 395 427 459 491 275 ... 507 t[12]: 260 292 324 356 388 420 452 484 268 + // 300 332 364 396 428 460 492 276 ... 508 t[13]: 261 293 325 357 389 421 453 + // 485 269 301 333 365 397 429 461 493 277 ... 509 t[14]: 262 294 326 358 390 + // 422 454 486 270 302 334 366 398 430 462 494 278 ... 510 t[15]: 263 295 327 + // 359 391 423 455 487 271 303 335 367 399 431 463 495 279 ... 511 t[16]: 512 + // 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760 + // ... 
+ // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 + // ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; ++i) { + d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]); + d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]); + d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]); + d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]); + d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]); + d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]); + d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]); + d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]); + } + + // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 16 ... 496 + // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 17 ... 497 + // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 18 ... 498 + // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 19 ... 499 + // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 20 ... + // 500 t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 21 + // ... 501 t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 + // 22 ... 502 t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 + // 487 23 ... 503 t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 + // 456 488 24 ... 504 t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 + // 425 457 489 25 ... 505 t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 + // 394 426 458 490 26 ... 506 t[11]: 11 43 75 107 139 171 203 235 267 299 331 + // 363 395 427 459 491 27 ... 507 t[12]: 12 44 76 108 140 172 204 236 268 300 + // 332 364 396 428 460 492 28 ... 508 t[13]: 13 45 77 109 141 173 205 237 269 + // 301 333 365 397 429 461 493 29 ... 509 t[14]: 14 46 78 110 142 174 206 238 + // 270 302 334 366 398 430 462 494 30 ... 510 t[15]: 15 47 79 111 143 175 207 + // 239 271 303 335 367 399 431 463 495 31 ... 511 t[16]: 512 544 576 608 640 + // 672 704 736 768 800 832 864 896 928 960 992 528 ... 1008 + // ... + // t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007 543 + // ... 1023 + __m512i const1 = _mm512_set_epi64( + 0x000000000000000d, + 0x000000000000000c, + 0x0000000000000005, + 0x0000000000000004, + 0x0000000000000009, + 0x0000000000000008, + 0x0000000000000001, + 0x0000000000000000); + __m512i const2 = _mm512_set_epi64( + 0x000000000000000f, + 0x000000000000000e, + 0x0000000000000007, + 0x0000000000000006, + 0x000000000000000b, + 0x000000000000000a, + 0x0000000000000003, + 0x0000000000000002); +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; ++i) { + r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/ const1, d[i + 8]); + r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/ const2, d[i + 8]); + r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const1, d[i + 24]); + r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const2, d[i + 24]); + } + + // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 512 544 + // ... 992 t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 + // 513 545 ... 993 t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 + // 450 482 514 546 ... 994 t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 + // 387 419 451 483 515 547 ... 995 t[4]: 4 36 68 100 132 164 196 228 260 292 + // 324 356 388 420 452 484 516 548 ... 
996 t[5]: 5 37 69 101 133 165 197 229 + // 261 293 325 357 389 421 453 485 517 549 ... 997 t[6]: 6 38 70 102 134 166 + // 198 230 262 294 326 358 390 422 454 486 518 550 ... 998 t[7]: 7 39 71 103 + // 135 167 199 231 263 295 327 359 391 423 455 487 519 551 ... 999 t[8]: 8 40 + // 72 104 136 168 200 232 264 296 328 360 392 424 456 488 520 552 ... 1000 + // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 521 553 + // ... 1001 t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 + // 490 522 554 ... 1002 t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 + // 395 427 459 491 523 555 ... 1003 t[12]: 12 44 76 108 140 172 204 236 268 + // 300 332 364 396 428 460 492 524 556 ... 1004 t[13]: 13 45 77 109 141 173 + // 205 237 269 301 333 365 397 429 461 493 525 557 ... 1005 t[14]: 14 46 78 + // 110 142 174 206 238 270 302 334 366 398 430 462 494 526 558 ... 1006 t[15]: + // 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 527 559 ... + // 1007 t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496 + // 528 560 ... 1008 + // ... + // t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511 543 575 + // ... 1023 + __m512i const3 = _mm512_set_epi64( + 0x000000000000000b, + 0x000000000000000a, + 0x0000000000000009, + 0x0000000000000008, + 0x0000000000000003, + 0x0000000000000002, + 0x0000000000000001, + 0x0000000000000000); + __m512i const4 = _mm512_set_epi64( + 0x000000000000000f, + 0x000000000000000e, + 0x000000000000000d, + 0x000000000000000c, + 0x0000000000000007, + 0x0000000000000006, + 0x0000000000000005, + 0x0000000000000004); +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; ++i) { + d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/ const3, r[i + 16]); + d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/ const4, r[i + 16]); + } +} + +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6 +template <> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst, + int M, + int N) { + // load from src + TORCH_CHECK( + M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); + __m512i r[32]; + int i; + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } + } + for (; i < 32; ++i) { + r[i] = _mm512_setzero_si512(); + } + + __m512i d[32]; + _transpose_mxn_half_32_32(r, d); + + // store to dst + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t< + std::is_same_v && + ((M <= 32 && M != 16) || (N <= 32 && N != 16)), + int> = 0> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +template <> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst, + int M, + int N) { + TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); + // load from src + __m512i r[32]; + int i; + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 
src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } + } + for (; i < 32; ++i) { + r[i] = _mm512_setzero_si512(); + } + + __m512i d[32]; + _transpose_mxn_half_32_32(r, d); + + // store to dst + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t< + std::is_same_v && + ((M <= 32 && M != 16) || (N <= 32 && N != 16)), + int> = 0> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = Half; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. 
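// The propagation below relies on the fact that an all-ones 32-bit pattern is
// itself a NaN (sign 1, exponent all ones, non-zero mantissa), so OR-ing the
// unordered-compare mask into the result forces those lanes to NaN.
// Minimal sketch of that bit-level fact (assumes <cstring>, <cmath>, <cassert>):
//   uint32_t bits = 0xFFFFFFFFu;
//   float f;
//   std::memcpy(&f, &bits, sizeof f);
//   assert(std::isnan(f));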
+template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(min), min_lo, min_hi); + cvtfp16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +inline void convert(const Half* src, Half* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, Half* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = _mm512_loadu_ps(&src[i]); + __m512 b = _mm512_loadu_ps(&src[i + 16]); + 
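// Each iteration consumes Vectorized<Half>::size() == 32 floats: the two
// 16-lane fp32 loads above are packed into one 512-bit vector of halves by
// cvtfp32_fp16(a, b) and stored below; the scalar tail covers the remainder.
// Per-element sketch of the same conversion (template argument assumed):
//   dst[i] = c10::convert<Half>(src[i]);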
+ __m512i bf = cvtfp32_fp16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, Half* dst, int64_t n) { + auto load_float = [](const double* src) -> __m512 { + // Load one float vector from an array of doubles + __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src)); + __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8)); + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = load_float(&src[i]); + __m512 b = load_float(&src[i + 16]); + + __m512i bf = cvtfp32_fp16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + cvtfp16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_fp16(o1, o2); +} + +#define CONVERT_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + __m512 o1, o2; \ + cvt_to_fp32(__m512i(a), o1, o2); \ + return std::make_tuple(o1, o2); \ + } \ + \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + return cvt_from_fp32(__m512(a), __m512(b)); \ + } +CONVERT_VECTORIZED_INIT(BFloat16, bfloat16) +CONVERT_VECTORIZED_INIT(Half, half) + +#else // defined(CPU_CAPABILITY_AVX512) + +#define CONVERT_NON_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr2); \ + for (const auto k : c10::irange(K)) { \ + arr[k] = c10::convert(arr2[k]); \ + } \ + return std::make_tuple( \ + Vectorized::loadu(arr), \ + Vectorized::loadu(arr + Vectorized::size())); \ + } \ + \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr); \ + b.store(arr + Vectorized::size()); \ + for (const auto k : c10::irange(K)) { \ + arr2[k] = c10::convert(arr[k]); \ + } \ + return Vectorized::loadu(arr2); \ + } +CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) +CONVERT_NON_VECTORIZED_INIT(Half, half) + +#endif // defined(CPU_CAPABILITY_AVX512) + +#if defined(CPU_CAPABILITY_AVX512) +#define LOAD_FP32_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + auto values = _mm256_loadu_si256(reinterpret_cast(data)); \ + __m512 out_values; \ + cvt_to_fp32(values, out_values); \ + out = out_values; \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + auto vec = Vectorized::loadu(data); \ + __m512 out1_values, out2_values; \ + cvt_to_fp32(vec, out1_values, out2_values); \ + out1 = out1_values; \ + out2 = out2_values; \ + } +LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16) +LOAD_FP32_VECTORIZED_INIT(Half, fp16) + +#else // defined(CPU_CAPABILITY_AVX512) +#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \ + inline void 
load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + __at_align__ float values[Vectorized::size()]; \ + for (const auto k : c10::irange(Vectorized::size())) { \ + values[k] = data[k]; \ + } \ + out = Vectorized::loadu(values); \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + load_fp32_from_##name(data, out1); \ + data += Vectorized::size(); \ + load_fp32_from_##name(data, out2); \ + } +LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16) +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16) + +#endif +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h new file mode 100644 index 0000000000000000000000000000000000000000..755b5887f2372d2af1cd645593a65f759cda0660 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h @@ -0,0 +1,654 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m512d values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm512_setr_pd( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4) { + values = _mm512_setr_pd( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag()); + } + operator __m512d() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + // NOLINTNEXTLINE(clang-diagnostic-warning) + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_pd( + 0x03, a.values, b.values); // b0000 0001 = b0000 0011 + case 2: + return _mm512_mask_blend_pd( + 0x0C, a.values, b.values); // b0000 0010 = b0000 1100 + case 3: + return _mm512_mask_blend_pd( + 0x0F, a.values, b.values); // b0000 0011 = b0000 1111 + case 4: + return _mm512_mask_blend_pd( + 0x30, a.values, b.values); // b0000 0100 = b0011 0000 + case 5: + return _mm512_mask_blend_pd( + 0x33, a.values, b.values); // b0000 0101 = b0011 0011 + case 6: + return _mm512_mask_blend_pd( + 0x3C, a.values, b.values); // b0000 0110 = b0011 1100 + case 7: + return _mm512_mask_blend_pd( + 0x3F, a.values, b.values); // b0000 0111 = b0011 1111 + case 8: + return _mm512_mask_blend_pd( + 0xC0, a.values, b.values); // b0000 1000 = b1100 0000 + case 9: + return _mm512_mask_blend_pd( + 0xC3, a.values, b.values); // b0000 1001 = b1100 0011 + case 10: + return _mm512_mask_blend_pd( + 0xCC, a.values, b.values); // b0000 1010 = b1100 1100 
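// Each c10::complex<double> lane occupies two double lanes, so the 4-bit
// complex blend mask is widened by duplicating every bit (xy -> xxyy), as the
// binary annotations on the cases show. Sketch of that expansion
// (hypothetical helper, assumes <cstdint>):
//   constexpr uint8_t expand_complex_mask(uint8_t m4) {
//     uint8_t m8 = 0;
//     for (int i = 0; i < 4; ++i)
//       if (m4 & (1u << i)) m8 |= uint8_t(0x3u << (2 * i));
//     return m8;
//   }
//   static_assert(expand_complex_mask(0b1010) == 0xCC, "matches case 10 above");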
+ case 11: + return _mm512_mask_blend_pd( + 0xCF, a.values, b.values); // b0000 1011 = b1100 1111 + case 12: + return _mm512_mask_blend_pd( + 0xF0, a.values, b.values); // b0000 1100 = b1111 0000 + case 13: + return _mm512_mask_blend_pd( + 0xF3, a.values, b.values); // b0000 1101 = b1111 0011 + case 14: + return _mm512_mask_blend_pd( + 0xFC, a.values, b.values); // b0000 1110 = b1111 1100 + case 15: + return _mm512_mask_blend_pd( + 0xFF, a.values, b.values); // b0000 1111 = b1111 1111 + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values); + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask( + _mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + c10::complex(1) * step, + base + c10::complex(2) * step, + base + c10::complex(3) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2 * size()]; + _mm512_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. 
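// hadd_pd()/hsub_pd() below emulate pairwise horizontal add/sub over the
// interleaved (re, im, re, im, ...) layout; abs_2_() uses hadd_pd(v*v, v*v)
// so every lane of a complex slot ends up holding re^2 + im^2.
// Scalar sketch of that per-element result:
//   double abs2_ref(c10::complex<double> z) {
//     return z.real() * z.real() + z.imag() * z.imag();
//   }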
+ static inline __m512d hadd_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_add_pd( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + static inline __m512d hsub_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_sub_pd( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + __m512d abs_2_() const { + auto val_2 = _mm512_mul_pd(values, values); // a*a b*b + return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m512d abs_() const { + auto real = _mm512_movedup_pd(values); // real real + // movehdup_pd does not exist... + auto imag = _mm512_permute_pd(values, 0xff); // imag imag + return Sleef_hypotd8_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm512_and_pd(abs_(), real_mask); // abs 0 + } + __m512d angle_() const { + // angle = atan2(b/a) + auto b_a = _mm512_permute_pd(values, 0x55); // b a + return Sleef_atan2d8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle + return _mm512_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_pd(); + auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ); + auto div = _mm512_div_pd(values, abs); + return _mm512_mask_blend_pd(mask, div, zero); + } + __m512d real_() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm512_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512d imag_() const { + const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF)); + return _mm512_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_pd(imag_(), 0x55); // b a + } + __m512d conj_() const { + const __m512d sign_mask = + _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm512_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. 
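// log2() and log10() below derive from log() via the change-of-base identity
// log_b(z) = log(z) / log(b), applied lane-wise after this element-wise map.
// Scalar sketch (assumes the std::log overload for c10::complex used by map()
// and division of a complex by a real scalar):
//   c10::complex<double> log2_ref(c10::complex<double> z) {
//     return std::log(z) / std::log(2.0);
//   }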
+ return map(std::log); + } + Vectorized> log2() const { + const __m512d log2_ = _mm512_set1_pd(std::log(2)); + return _mm512_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m512d log10_ = _mm512_set1_pd(std::log(10)); + return _mm512_div_pd(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m512d one = _mm512_set1_pd(1); + + // auto conj = conj_(); + // auto b_a = _mm512_permute_pd(conj, 0x55); //-b a + // auto ab = _mm512_mul_pd(conj, b_a); //-ab + // -ab auto im = _mm512_add_pd(ab, ab); //-2ab -2ab + + // auto val_2 = _mm512_mul_pd(values, values); // a*a + // b*b auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b + // b*b-a*a re = _mm512_sub_pd(one, re); + + // auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); + // //sqrt(re + i*im) auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); + // //ln(iz + sqrt()) return Vectorized(_mm512_permute_pd(ln.values, + // 0x55)).conj(); //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m512d pi_2 = + _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0); + return _mm512_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expd8_u10(values); //exp(a) exp(b) exp = + // _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm512_mask_blend_pd(0xAA, + // _mm512_permute_pd(sin_cos.y, 0x55), + // sin_cos.x); //cos(b) + // sin(b) + // return _mm512_mul_pd(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m512d ln_2 = _mm512_set1_pd(c10::ln_2); + Vectorized> scaled_values = + _mm512_mul_pd(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized> floor() const { + return _mm512_floor_pd(values); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_pd(); + return _mm512_sub_pd(zero, values); + } + Vectorized> round() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator!=( + const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator<( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512d sign_mask = + _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm512_mul_pd(a, b); // ac bd + + auto d_c = _mm512_permute_pd(b, 0x55); // d c + d_c = _mm512_xor_pd(sign_mask, d_c); // d -c + auto ad_bc = _mm512_mul_pd(a, d_c); // ad -bc + + auto ret = Vectorized>::hsub_pd( + ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm512_set1_pd(-0.f); + // auto fabs_cd = _mm512_andnot_pd(mask, b); // |c| |d| + // auto fabs_dc = _mm512_permute_pd(fabs_cd, 0x55); // |d| |c| + // auto scale = _mm512_rcp14_pd(_mm512_max_pd(fabs_cd, fabs_dc)); // 1/sc + // 1/sc auto a2 = _mm512_mul_pd(a, scale); // a/sc b/sc auto b2 = + // _mm512_mul_pd(b, scale); // c/sc d/sc auto acbd2 = + // _mm512_mul_pd(a2, b2); + + // const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); auto dc2 = _mm512_permute_pd(b2, 0x55); // d/sc c/sc + // dc2 = _mm512_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm512_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = Vectorized>::hadd_pd(acbd2, adbc2); + // //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm512_div_pd(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm512_loadu_pd(reinterpret_cast(out)); +} + +// reciprocal. 
Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); auto c_d = _mm512_xor_pd(sign_mask, values); //c -d + // return _mm512_div_pd(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm512_add_pd(i, values)); // a + // 1+b auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(max, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. 
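// maximum()/minimum() for complex values compare by squared magnitude
// (abs_2_()) and propagate NaN, mirroring the real-valued versions.
// Scalar sketch of minimum under those semantics (assumes <cmath>):
//   c10::complex<double> minimum_ref(c10::complex<double> a,
//                                    c10::complex<double> b) {
//     double ma = a.real() * a.real() + a.imag() * a.imag();
//     double mb = b.real() * b.real() + b.imag() * b.imag();
//     if (std::isnan(ma) || std::isnan(mb)) return {NAN, NAN};
//     return (ma > mb) ? b : a;
//   }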
+ auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(min, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_pd(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm512_set1_pd(1.0)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm512_set1_pd(1.0)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h new file mode 100644 index 0000000000000000000000000000000000000000..3c80fe924695652034ec5744856fd88d7f4d8872 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h @@ -0,0 +1,1222 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m512 values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm512_setr_ps( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4, + c10::complex val5, + c10::complex val6, + c10::complex val7, + c10::complex val8) { + values = _mm512_setr_ps( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag(), + val5.real(), + val5.imag(), + val6.real(), + val6.imag(), + val7.real(), + val7.imag(), + val8.real(), + val8.imag()); + } + operator __m512() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 256, "Unexpected mask value"); + // The 
compiler would hopefully convert this switch condition + // into a jump table + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_ps(0x03, a.values, b.values); + case 2: + return _mm512_mask_blend_ps(0x0C, a.values, b.values); + case 3: + return _mm512_mask_blend_ps(0x0F, a.values, b.values); + case 4: + return _mm512_mask_blend_ps(0x30, a.values, b.values); + case 5: + return _mm512_mask_blend_ps(0x33, a.values, b.values); + case 6: + return _mm512_mask_blend_ps(0x3C, a.values, b.values); + case 7: + return _mm512_mask_blend_ps(0x3F, a.values, b.values); + case 8: + return _mm512_mask_blend_ps(0xC0, a.values, b.values); + case 9: + return _mm512_mask_blend_ps(0xC3, a.values, b.values); + case 10: + return _mm512_mask_blend_ps(0xCC, a.values, b.values); + case 11: + return _mm512_mask_blend_ps(0xCF, a.values, b.values); + case 12: + return _mm512_mask_blend_ps(0xF0, a.values, b.values); + case 13: + return _mm512_mask_blend_ps(0xF3, a.values, b.values); + case 14: + return _mm512_mask_blend_ps(0xFC, a.values, b.values); + case 15: + return _mm512_mask_blend_ps(0xFF, a.values, b.values); + case 16: + return _mm512_mask_blend_ps(0x300, a.values, b.values); + case 17: + return _mm512_mask_blend_ps(0x303, a.values, b.values); + case 18: + return _mm512_mask_blend_ps(0x30C, a.values, b.values); + case 19: + return _mm512_mask_blend_ps(0x30F, a.values, b.values); + case 20: + return _mm512_mask_blend_ps(0x330, a.values, b.values); + case 21: + return _mm512_mask_blend_ps(0x333, a.values, b.values); + case 22: + return _mm512_mask_blend_ps(0x33C, a.values, b.values); + case 23: + return _mm512_mask_blend_ps(0x33F, a.values, b.values); + case 24: + return _mm512_mask_blend_ps(0x3C0, a.values, b.values); + case 25: + return _mm512_mask_blend_ps(0x3C3, a.values, b.values); + case 26: + return _mm512_mask_blend_ps(0x3CC, a.values, b.values); + case 27: + return _mm512_mask_blend_ps(0x3CF, a.values, b.values); + case 28: + return _mm512_mask_blend_ps(0x3F0, a.values, b.values); + case 29: + return _mm512_mask_blend_ps(0x3F3, a.values, b.values); + case 30: + return _mm512_mask_blend_ps(0x3FC, a.values, b.values); + case 31: + return _mm512_mask_blend_ps(0x3FF, a.values, b.values); + case 32: + return _mm512_mask_blend_ps(0xC00, a.values, b.values); + case 33: + return _mm512_mask_blend_ps(0xC03, a.values, b.values); + case 34: + return _mm512_mask_blend_ps(0xC0C, a.values, b.values); + case 35: + return _mm512_mask_blend_ps(0xC0F, a.values, b.values); + case 36: + return _mm512_mask_blend_ps(0xC30, a.values, b.values); + case 37: + return _mm512_mask_blend_ps(0xC33, a.values, b.values); + case 38: + return _mm512_mask_blend_ps(0xC3C, a.values, b.values); + case 39: + return _mm512_mask_blend_ps(0xC3F, a.values, b.values); + case 40: + return _mm512_mask_blend_ps(0xCC0, a.values, b.values); + case 41: + return _mm512_mask_blend_ps(0xCC3, a.values, b.values); + case 42: + return _mm512_mask_blend_ps(0xCCC, a.values, b.values); + case 43: + return _mm512_mask_blend_ps(0xCCF, a.values, b.values); + case 44: + return _mm512_mask_blend_ps(0xCF0, a.values, b.values); + case 45: + return _mm512_mask_blend_ps(0xCF3, a.values, b.values); + case 46: + return _mm512_mask_blend_ps(0xCFC, a.values, b.values); + case 47: + return _mm512_mask_blend_ps(0xCFF, a.values, b.values); + case 48: + return _mm512_mask_blend_ps(0xF00, a.values, b.values); + case 49: + return _mm512_mask_blend_ps(0xF03, a.values, b.values); + case 50: + return _mm512_mask_blend_ps(0xF0C, a.values, b.values); + case 51: 
+ return _mm512_mask_blend_ps(0xF0F, a.values, b.values); + case 52: + return _mm512_mask_blend_ps(0xF30, a.values, b.values); + case 53: + return _mm512_mask_blend_ps(0xF33, a.values, b.values); + case 54: + return _mm512_mask_blend_ps(0xF3C, a.values, b.values); + case 55: + return _mm512_mask_blend_ps(0xF3F, a.values, b.values); + case 56: + return _mm512_mask_blend_ps(0xFC0, a.values, b.values); + case 57: + return _mm512_mask_blend_ps(0xFC3, a.values, b.values); + case 58: + return _mm512_mask_blend_ps(0xFCC, a.values, b.values); + case 59: + return _mm512_mask_blend_ps(0xFCF, a.values, b.values); + case 60: + return _mm512_mask_blend_ps(0xFF0, a.values, b.values); + case 61: + return _mm512_mask_blend_ps(0xFF3, a.values, b.values); + case 62: + return _mm512_mask_blend_ps(0xFFC, a.values, b.values); + case 63: + return _mm512_mask_blend_ps(0xFFF, a.values, b.values); + case 64: + return _mm512_mask_blend_ps(0x3000, a.values, b.values); + case 65: + return _mm512_mask_blend_ps(0x3003, a.values, b.values); + case 66: + return _mm512_mask_blend_ps(0x300C, a.values, b.values); + case 67: + return _mm512_mask_blend_ps(0x300F, a.values, b.values); + case 68: + return _mm512_mask_blend_ps(0x3030, a.values, b.values); + case 69: + return _mm512_mask_blend_ps(0x3033, a.values, b.values); + case 70: + return _mm512_mask_blend_ps(0x303C, a.values, b.values); + case 71: + return _mm512_mask_blend_ps(0x303F, a.values, b.values); + case 72: + return _mm512_mask_blend_ps(0x30C0, a.values, b.values); + case 73: + return _mm512_mask_blend_ps(0X30C3, a.values, b.values); + case 74: + return _mm512_mask_blend_ps(0x30CC, a.values, b.values); + case 75: + return _mm512_mask_blend_ps(0x30CF, a.values, b.values); + case 76: + return _mm512_mask_blend_ps(0x30F0, a.values, b.values); + case 77: + return _mm512_mask_blend_ps(0x30F3, a.values, b.values); + case 78: + return _mm512_mask_blend_ps(0x30FC, a.values, b.values); + case 79: + return _mm512_mask_blend_ps(0x30FF, a.values, b.values); + case 80: + return _mm512_mask_blend_ps(0x3300, a.values, b.values); + case 81: + return _mm512_mask_blend_ps(0X3303, a.values, b.values); + case 82: + return _mm512_mask_blend_ps(0x330C, a.values, b.values); + case 83: + return _mm512_mask_blend_ps(0x330F, a.values, b.values); + case 84: + return _mm512_mask_blend_ps(0x3330, a.values, b.values); + case 85: + return _mm512_mask_blend_ps(0x3333, a.values, b.values); + case 86: + return _mm512_mask_blend_ps(0x333C, a.values, b.values); + case 87: + return _mm512_mask_blend_ps(0X333F, a.values, b.values); + case 88: + return _mm512_mask_blend_ps(0x33C0, a.values, b.values); + case 89: + return _mm512_mask_blend_ps(0x33C3, a.values, b.values); + case 90: + return _mm512_mask_blend_ps(0x33CC, a.values, b.values); + case 91: + return _mm512_mask_blend_ps(0x33CF, a.values, b.values); + case 92: + return _mm512_mask_blend_ps(0x33F0, a.values, b.values); + case 93: + return _mm512_mask_blend_ps(0x33F3, a.values, b.values); + case 94: + return _mm512_mask_blend_ps(0x33FC, a.values, b.values); + case 95: + return _mm512_mask_blend_ps(0x33FF, a.values, b.values); + case 96: + return _mm512_mask_blend_ps(0X3C00, a.values, b.values); + case 97: + return _mm512_mask_blend_ps(0x3C03, a.values, b.values); + case 98: + return _mm512_mask_blend_ps(0x3C0C, a.values, b.values); + case 99: + return _mm512_mask_blend_ps(0x3C0F, a.values, b.values); + case 100: + return _mm512_mask_blend_ps(0x3C30, a.values, b.values); + case 101: + return _mm512_mask_blend_ps(0x3C33, a.values, b.values); + case 
102: + return _mm512_mask_blend_ps(0x3C3C, a.values, b.values); + case 103: + return _mm512_mask_blend_ps(0x3C3F, a.values, b.values); + case 104: + return _mm512_mask_blend_ps(0x3CC0, a.values, b.values); + case 105: + return _mm512_mask_blend_ps(0x3CC3, a.values, b.values); + case 106: + return _mm512_mask_blend_ps(0x3CCC, a.values, b.values); + case 107: + return _mm512_mask_blend_ps(0x3CCF, a.values, b.values); + case 108: + return _mm512_mask_blend_ps(0x3CF0, a.values, b.values); + case 109: + return _mm512_mask_blend_ps(0x3CF3, a.values, b.values); + case 110: + return _mm512_mask_blend_ps(0x3CFC, a.values, b.values); + case 111: + return _mm512_mask_blend_ps(0x3CFF, a.values, b.values); + case 112: + return _mm512_mask_blend_ps(0x3F00, a.values, b.values); + case 113: + return _mm512_mask_blend_ps(0x3F03, a.values, b.values); + case 114: + return _mm512_mask_blend_ps(0x3F0C, a.values, b.values); + case 115: + return _mm512_mask_blend_ps(0x3F0F, a.values, b.values); + case 116: + return _mm512_mask_blend_ps(0x3F30, a.values, b.values); + case 117: + return _mm512_mask_blend_ps(0x3F33, a.values, b.values); + case 118: + return _mm512_mask_blend_ps(0x3F3C, a.values, b.values); + case 119: + return _mm512_mask_blend_ps(0x3F3F, a.values, b.values); + case 120: + return _mm512_mask_blend_ps(0x3FC0, a.values, b.values); + case 121: + return _mm512_mask_blend_ps(0x3FC3, a.values, b.values); + case 122: + return _mm512_mask_blend_ps(0x3FCC, a.values, b.values); + case 123: + return _mm512_mask_blend_ps(0x3FCF, a.values, b.values); + case 124: + return _mm512_mask_blend_ps(0x3FF0, a.values, b.values); + case 125: + return _mm512_mask_blend_ps(0x3FF3, a.values, b.values); + case 126: + return _mm512_mask_blend_ps(0x3FFC, a.values, b.values); + case 127: + return _mm512_mask_blend_ps(0x3FFF, a.values, b.values); + case 128: + return _mm512_mask_blend_ps(0xC000, a.values, b.values); + case 129: + return _mm512_mask_blend_ps(0xC003, a.values, b.values); + case 130: + return _mm512_mask_blend_ps(0xC00C, a.values, b.values); + case 131: + return _mm512_mask_blend_ps(0xC00F, a.values, b.values); + case 132: + return _mm512_mask_blend_ps(0xC030, a.values, b.values); + case 133: + return _mm512_mask_blend_ps(0xC033, a.values, b.values); + case 134: + return _mm512_mask_blend_ps(0xC03C, a.values, b.values); + case 135: + return _mm512_mask_blend_ps(0xC03F, a.values, b.values); + case 136: + return _mm512_mask_blend_ps(0xC0C0, a.values, b.values); + case 137: + return _mm512_mask_blend_ps(0xC0C3, a.values, b.values); + case 138: + return _mm512_mask_blend_ps(0xC0CC, a.values, b.values); + case 139: + return _mm512_mask_blend_ps(0xC0CF, a.values, b.values); + case 140: + return _mm512_mask_blend_ps(0xC0F0, a.values, b.values); + case 141: + return _mm512_mask_blend_ps(0xC0F3, a.values, b.values); + case 142: + return _mm512_mask_blend_ps(0xC0FC, a.values, b.values); + case 143: + return _mm512_mask_blend_ps(0xC0FF, a.values, b.values); + case 144: + return _mm512_mask_blend_ps(0xC300, a.values, b.values); + case 145: + return _mm512_mask_blend_ps(0xC303, a.values, b.values); + case 146: + return _mm512_mask_blend_ps(0xC30C, a.values, b.values); + case 147: + return _mm512_mask_blend_ps(0xC30F, a.values, b.values); + case 148: + return _mm512_mask_blend_ps(0xC330, a.values, b.values); + case 149: + return _mm512_mask_blend_ps(0xC333, a.values, b.values); + case 150: + return _mm512_mask_blend_ps(0xC33C, a.values, b.values); + case 151: + return _mm512_mask_blend_ps(0xC33F, a.values, b.values); + case 152: 
+ return _mm512_mask_blend_ps(0xC3C0, a.values, b.values); + case 153: + return _mm512_mask_blend_ps(0xC3C3, a.values, b.values); + case 154: + return _mm512_mask_blend_ps(0xC3CC, a.values, b.values); + case 155: + return _mm512_mask_blend_ps(0xC3CF, a.values, b.values); + case 156: + return _mm512_mask_blend_ps(0xC3F0, a.values, b.values); + case 157: + return _mm512_mask_blend_ps(0xC3F3, a.values, b.values); + case 158: + return _mm512_mask_blend_ps(0xC3FC, a.values, b.values); + case 159: + return _mm512_mask_blend_ps(0xC3FF, a.values, b.values); + case 160: + return _mm512_mask_blend_ps(0xCC00, a.values, b.values); + case 161: + return _mm512_mask_blend_ps(0xCC03, a.values, b.values); + case 162: + return _mm512_mask_blend_ps(0xCC0C, a.values, b.values); + case 163: + return _mm512_mask_blend_ps(0xCC0F, a.values, b.values); + case 164: + return _mm512_mask_blend_ps(0xCC30, a.values, b.values); + case 165: + return _mm512_mask_blend_ps(0xCC33, a.values, b.values); + case 166: + return _mm512_mask_blend_ps(0xCC3C, a.values, b.values); + case 167: + return _mm512_mask_blend_ps(0xCC3F, a.values, b.values); + case 168: + return _mm512_mask_blend_ps(0xCCC0, a.values, b.values); + case 169: + return _mm512_mask_blend_ps(0xCCC3, a.values, b.values); + case 170: + return _mm512_mask_blend_ps(0xCCCC, a.values, b.values); + case 171: + return _mm512_mask_blend_ps(0xCCCF, a.values, b.values); + case 172: + return _mm512_mask_blend_ps(0xCCF0, a.values, b.values); + case 173: + return _mm512_mask_blend_ps(0xCCF3, a.values, b.values); + case 174: + return _mm512_mask_blend_ps(0xCCFC, a.values, b.values); + case 175: + return _mm512_mask_blend_ps(0xCCFF, a.values, b.values); + case 176: + return _mm512_mask_blend_ps(0xCF00, a.values, b.values); + case 177: + return _mm512_mask_blend_ps(0xCF03, a.values, b.values); + case 178: + return _mm512_mask_blend_ps(0xCF0C, a.values, b.values); + case 179: + return _mm512_mask_blend_ps(0xCF0F, a.values, b.values); + case 180: + return _mm512_mask_blend_ps(0xCF30, a.values, b.values); + case 181: + return _mm512_mask_blend_ps(0xCF33, a.values, b.values); + case 182: + return _mm512_mask_blend_ps(0xCF3C, a.values, b.values); + case 183: + return _mm512_mask_blend_ps(0xCF3F, a.values, b.values); + case 184: + return _mm512_mask_blend_ps(0xCFC0, a.values, b.values); + case 185: + return _mm512_mask_blend_ps(0xCFC3, a.values, b.values); + case 186: + return _mm512_mask_blend_ps(0xCFCC, a.values, b.values); + case 187: + return _mm512_mask_blend_ps(0xCFCF, a.values, b.values); + case 188: + return _mm512_mask_blend_ps(0xCFF0, a.values, b.values); + case 189: + return _mm512_mask_blend_ps(0xCFF3, a.values, b.values); + case 190: + return _mm512_mask_blend_ps(0xCFFC, a.values, b.values); + case 191: + return _mm512_mask_blend_ps(0xCFFF, a.values, b.values); + case 192: + return _mm512_mask_blend_ps(0xF000, a.values, b.values); + case 193: + return _mm512_mask_blend_ps(0xF003, a.values, b.values); + case 194: + return _mm512_mask_blend_ps(0xF00C, a.values, b.values); + case 195: + return _mm512_mask_blend_ps(0xF00F, a.values, b.values); + case 196: + return _mm512_mask_blend_ps(0xF030, a.values, b.values); + case 197: + return _mm512_mask_blend_ps(0xF033, a.values, b.values); + case 198: + return _mm512_mask_blend_ps(0xF03C, a.values, b.values); + case 199: + return _mm512_mask_blend_ps(0xF03F, a.values, b.values); + case 200: + return _mm512_mask_blend_ps(0XF0C0, a.values, b.values); + case 201: + return _mm512_mask_blend_ps(0xF0C3, a.values, b.values); + case 202: + 
return _mm512_mask_blend_ps(0xF0CC, a.values, b.values); + case 203: + return _mm512_mask_blend_ps(0xF0CF, a.values, b.values); + case 204: + return _mm512_mask_blend_ps(0xF0F0, a.values, b.values); + case 205: + return _mm512_mask_blend_ps(0xF0F3, a.values, b.values); + case 206: + return _mm512_mask_blend_ps(0xF0FC, a.values, b.values); + case 207: + return _mm512_mask_blend_ps(0xF0FF, a.values, b.values); + case 208: + return _mm512_mask_blend_ps(0XF300, a.values, b.values); + case 209: + return _mm512_mask_blend_ps(0xF303, a.values, b.values); + case 210: + return _mm512_mask_blend_ps(0xF30C, a.values, b.values); + case 211: + return _mm512_mask_blend_ps(0xF30F, a.values, b.values); + case 212: + return _mm512_mask_blend_ps(0xF330, a.values, b.values); + case 213: + return _mm512_mask_blend_ps(0xF333, a.values, b.values); + case 214: + return _mm512_mask_blend_ps(0XF33C, a.values, b.values); + case 215: + return _mm512_mask_blend_ps(0xF33F, a.values, b.values); + case 216: + return _mm512_mask_blend_ps(0xF3C0, a.values, b.values); + case 217: + return _mm512_mask_blend_ps(0xF3C3, a.values, b.values); + case 218: + return _mm512_mask_blend_ps(0xF3CC, a.values, b.values); + case 219: + return _mm512_mask_blend_ps(0xF3CF, a.values, b.values); + case 220: + return _mm512_mask_blend_ps(0xF3F0, a.values, b.values); + case 221: + return _mm512_mask_blend_ps(0xF3F3, a.values, b.values); + case 222: + return _mm512_mask_blend_ps(0xF3FC, a.values, b.values); + case 223: + return _mm512_mask_blend_ps(0XF3FF, a.values, b.values); + case 224: + return _mm512_mask_blend_ps(0xFC00, a.values, b.values); + case 225: + return _mm512_mask_blend_ps(0xFC03, a.values, b.values); + case 226: + return _mm512_mask_blend_ps(0xFC0C, a.values, b.values); + case 227: + return _mm512_mask_blend_ps(0xFC0F, a.values, b.values); + case 228: + return _mm512_mask_blend_ps(0xFC30, a.values, b.values); + case 229: + return _mm512_mask_blend_ps(0xFC33, a.values, b.values); + case 230: + return _mm512_mask_blend_ps(0xFC3C, a.values, b.values); + case 231: + return _mm512_mask_blend_ps(0xFC3F, a.values, b.values); + case 232: + return _mm512_mask_blend_ps(0xFCC0, a.values, b.values); + case 233: + return _mm512_mask_blend_ps(0xFCC3, a.values, b.values); + case 234: + return _mm512_mask_blend_ps(0xFCCC, a.values, b.values); + case 235: + return _mm512_mask_blend_ps(0xFCCF, a.values, b.values); + case 236: + return _mm512_mask_blend_ps(0xFCF0, a.values, b.values); + case 237: + return _mm512_mask_blend_ps(0xFCF3, a.values, b.values); + case 238: + return _mm512_mask_blend_ps(0xFCFC, a.values, b.values); + case 239: + return _mm512_mask_blend_ps(0xFCFF, a.values, b.values); + case 240: + return _mm512_mask_blend_ps(0xFF00, a.values, b.values); + case 241: + return _mm512_mask_blend_ps(0xFF03, a.values, b.values); + case 242: + return _mm512_mask_blend_ps(0xFF0C, a.values, b.values); + case 243: + return _mm512_mask_blend_ps(0xFF0F, a.values, b.values); + case 244: + return _mm512_mask_blend_ps(0xFF30, a.values, b.values); + case 245: + return _mm512_mask_blend_ps(0xFF33, a.values, b.values); + case 246: + return _mm512_mask_blend_ps(0xFF3C, a.values, b.values); + case 247: + return _mm512_mask_blend_ps(0xFF3F, a.values, b.values); + case 248: + return _mm512_mask_blend_ps(0xFFC0, a.values, b.values); + case 249: + return _mm512_mask_blend_ps(0xFFC3, a.values, b.values); + case 250: + return _mm512_mask_blend_ps(0xFFCC, a.values, b.values); + case 251: + return _mm512_mask_blend_ps(0xFFCF, a.values, b.values); + case 252: + 
return _mm512_mask_blend_ps(0xFFF0, a.values, b.values); + case 253: + return _mm512_mask_blend_ps(0xFFF3, a.values, b.values); + case 254: + return _mm512_mask_blend_ps(0xFFFC, a.values, b.values); + default: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_ps(mask.values, mask.values); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + step, + base + c10::complex(2) * step, + base + c10::complex(3) * step, + base + c10::complex(4) * step, + base + c10::complex(5) * step, + base + c10::complex(6) * step, + base + c10::complex(7) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2 * size()]; + _mm512_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. 
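// --- Illustrative sketch (editorial, not part of the vendored header) ---
// The blend() switch above maps its 8-bit template mask (one bit per
// c10::complex<float> element) to the 16-bit lane mask expected by
// _mm512_mask_blend_ps: element bit k selects both the real and imaginary
// float lanes, i.e. the two-bit group 0x3 << (2*k). A minimal constexpr
// check of that expansion (names are hypothetical):
constexpr unsigned expand_complex_mask(unsigned mask) {
  unsigned out = 0;
  for (int k = 0; k < 8; ++k) {
    if (mask & (1u << k)) {
      out |= 0x3u << (2 * k);
    }
  }
  return out;
}
static_assert(expand_complex_mask(1) == 0x03, "case 1 uses 0x03");
static_assert(expand_complex_mask(2) == 0x0C, "case 2 uses 0x0C");
static_assert(expand_complex_mask(170) == 0xCCCC, "case 170 uses 0xCCCC");
// --- end of sketch ---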
+ static inline __m512 hadd_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32( + 30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_add_ps( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + static inline __m512 hsub_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32( + 30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_sub_ps( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m512 abs_2_() const { + auto val_2 = _mm512_mul_ps(values, values); // a*a b*b + auto ret = hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return ret; + } + __m512 abs_() const { + auto real = _mm512_moveldup_ps(values); // real real + auto imag = _mm512_movehdup_ps(values); // imag imag + return Sleef_hypotf16_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm512_and_ps(abs_(), real_mask); // abs 0 + } + __m512 angle_() const { + // angle = atan2(b/a) + auto b_a = _mm512_permute_ps(values, 0xB1); // b a + return Sleef_atan2f16_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + auto angle = _mm512_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm512_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_ps(); + auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ); + auto div = _mm512_div_ps(values, abs); + return _mm512_mask_blend_ps(mask, div, zero); + } + __m512 real_() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm512_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512 imag_() const { + const __m512 imag_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF)); + return _mm512_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_ps(imag_(), 0xB1); // b a + } + __m512 conj_() const { + const 
__m512 sign_mask = _mm512_setr_ps( + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0); + return _mm512_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512 log2_ = _mm512_set1_ps(std::log(2)); + return _mm512_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m512 log10_ = _mm512_set1_ps(std::log(10)); + return _mm512_div_ps(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m512 one = _mm512_set1_ps(1); + + // auto conj = conj_(); + // auto b_a = _mm512_permute_ps(conj, 0xB1); //-b a + // auto ab = _mm512_mul_ps(conj, b_a); //-ab + // -ab auto im = _mm512_add_ps(ab, ab); //-2ab -2ab + + // auto val_2 = _mm512_mul_ps(values, values); // a*a + // b*b auto re = hsub_ps(val_2, _mm512_permute_ps(val_2, 0xB1)); // a*a-b*b + // b*b-a*a re = _mm512_sub_ps(one, re); + + // auto root = Vectorized(_mm512_mask_blend_ps(0xAAAA, re, im)).sqrt(); + // //sqrt(re + i*im) auto ln = Vectorized(_mm512_add_ps(b_a, root)).log(); + // //ln(iz + sqrt()) return Vectorized(_mm512_permute_ps(ln.values, + // 0xB1)).conj(); //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
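// --- Illustrative sketch (editorial, not part of the vendored header) ---
// The commented-out vectorized asin() above and the exp()/exp2() paths below
// rely on standard identities: asin(z) = -i*ln(i*z + sqrt(1 - z^2)),
// exp(a + bi) = exp(a)*(cos(b) + i*sin(b)), and exp2(z) = exp(z * ln(2)).
// A scalar reference using std::complex (function names are hypothetical):
#include <cmath>
#include <complex>
inline std::complex<float> asin_via_log(std::complex<float> z) {
  const std::complex<float> i(0.0f, 1.0f);
  return -i * std::log(i * z + std::sqrt(std::complex<float>(1.0f) - z * z));
}
inline std::complex<float> exp_via_euler(std::complex<float> z) {
  const float r = std::exp(z.real());
  return {r * std::cos(z.imag()), r * std::sin(z.imag())};
}
inline std::complex<float> exp2_via_exp(std::complex<float> z) {
  return std::exp(z * std::log(2.0f));  // 2**z == exp(z * ln 2)
}
// --- end of sketch ---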
+ // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expf16_u10(values); //exp(a) exp(b) exp = + // _mm512_mask_blend_ps(0xAAAA, exp, _mm512_permute_ps(exp, 0xB1)); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosf16_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm512_mask_blend_ps(0xAAAA, + // _mm512_permute_ps(sin_cos.y, 0xB1), + // sin_cos.x); //cos(b) + // sin(b) + // return _mm512_mul_ps(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m512 ln_2 = _mm512_set1_ps(c10::ln_2); + Vectorized> scaled_values = _mm512_mul_ps(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized> floor() const { + return _mm512_floor_ps(values); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_ps(); + return _mm512_sub_ps(zero, values); + } + Vectorized> round() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator!=( + const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator<( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512 sign_mask = _mm512_setr_ps( + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0); + auto ac_bd = _mm512_mul_ps(a, b); // ac bd + + auto d_c = _mm512_permute_ps(b, 0xB1); // d c + d_c = _mm512_xor_ps(sign_mask, d_c); // d -c + auto ad_bc = _mm512_mul_ps(a, d_c); // ad -bc + + auto ret = Vectorized>::hsub_ps( + ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. 
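// --- Illustrative sketch (editorial, not part of the vendored header) ---
// Scalar reference for operator* above: (a + bi) * (c + di)
// = (ac - bd) + (ad + bc)i. The vector code forms the elementwise products
// {a*c, b*d} and, after the sign-flipped permute, {a*d, -b*c}; hsub_ps then
// pairs them into {a*c - b*d, a*d + b*c}. (Function name is hypothetical.)
#include <complex>
inline std::complex<float> complex_mul_ref(
    std::complex<float> x,
    std::complex<float> y) {
  const float a = x.real(), b = x.imag(), c = y.real(), d = y.imag();
  return {a * c - b * d, a * d + b * c};
}
// --- end of sketch ---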
+ // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm512_set1_ps(-0.f); + // auto fabs_cd = _mm512_andnot_ps(mask, b); // |c| |d| + // auto fabs_dc = _mm512_permute_ps(fabs_cd, 0xB1); // |d| |c| + // auto scale = _mm512_rcp14_ps(_mm512_max_ps(fabs_cd, fabs_dc)); // 1/sc + // 1/sc auto a2 = _mm512_mul_ps(a, scale); // a/sc b/sc auto b2 = + // _mm512_mul_ps(b, scale); // c/sc d/sc auto acbd2 = + // _mm512_mul_ps(a2, b2); + + // const __m512 sign_mask = _mm512_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0, + // -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); + // auto dc2 = _mm512_permute_ps(b2, 0xB1); // d/sc c/sc + // dc2 = _mm512_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm512_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = Vectorized>::hadd_ps(acbd2, adbc2); + // //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm512_div_ps(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm512_loadu_ps(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0, + // 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); + // auto c_d = _mm512_xor_ps(sign_mask, values); //c -d + // return _mm512_div_ps(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m512 i = _mm512_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + // 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm512_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5, + // 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm512_add_ps(i, values)); // a + // 1+b auto sub = Vectorized(_mm512_sub_ps(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. 
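// --- Illustrative sketch (editorial, not part of the vendored header) ---
// "All-ones is a NaN": a 32-bit pattern of all 1s has every exponent bit set
// and a nonzero mantissa, which IEEE-754 classifies as NaN. maximum() and
// minimum() build an all-ones mask on unordered lanes and OR it into the
// result, forcing NaN wherever either input was NaN. A scalar check
// (function name is hypothetical):
#include <cmath>
#include <cstring>
inline bool all_ones_bits_are_nan() {
  const unsigned bits = 0xFFFFFFFFu;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return std::isnan(f);  // true
}
// --- end of sketch ---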
+ auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(max, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(min, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_ps(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm512_set1_ps(1.0f)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm512_set1_ps(1.0f)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..8c03183982436a51155d8975b1c7641dc9eb7b2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h @@ -0,0 +1,340 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + __m512 value; + cvtbf16_fp32(_mm512_castsi512_si256(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + __m512 value; + cvtfp16_fp32(_mm512_castsi512_si256(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = _mm512_castsi256_si512(cvtfp32_bf16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline 
VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = _mm512_castsi256_si512(cvtfp32_fp16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm512_cvtepi64_ps(src[0]); + auto high = _mm512_cvtepi64_ps(src[1]); + return Vectorized( + _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm512_cvt_roundps_epi64( + _mm512_castps512_ps256(src[0]), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + result[1] = _mm512_cvt_roundps_epi64( + _mm512_extractf32x8_ps(src[0], 1), + _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm512_cvtepi64_epi32(src[0]); + auto high = _mm512_cvtepi64_epi32(src[1]); + return Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src[0])); + result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src[0], 1)); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_castsi512_si128(src[0]); + return Vectorized(_mm512_cvtepi8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_castsi512_si128(src[0]); + return Vectorized(_mm512_cvtepu8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm512_cvttps_epi32(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm512_cvtepi32_ps(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src256 = _mm512_castsi512_si256(src[0]); + return Vectorized(_mm512_cvtepu8_epi16(src256)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_cvtepi32_epi8(src[0]); + return Vectorized(_mm512_castsi128_si512(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src256 = _mm512_cvtepi16_epi8(src[0]); + return Vectorized(_mm512_castsi256_si512(src256)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + src_t, + 1, + typename std::enable_if_t< + (is_reduced_floating_point_v && is_8bit_integer_v) || + (is_reduced_floating_point_v && is_8bit_integer_v), + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 
2, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2)); + __m512 result = _mm512_insertf32x4( + _mm512_castsi512_ps(vec1), + lane2, + 1); // Insert lane2 into the second 128-bit lane + return at::vec::Vectorized(_mm512_castps_si512(result)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_float_to_int8(src[0]); + } +}; + +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + __m512i src2 = + _mm512_castsi128_si512(_mm_castps_si128(_mm512_extractf32x4_ps( + _mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane + )); + return VectorizedN( + convert_int8_to_float(src[0]), + convert_int8_to_float(src2)); + } +}; + +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_to_float(src[0]); + } +}; + +template +struct VecConvert< + dst_t, + 1, + int64_t, + 2, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const VectorizedN& src) { + return VecConvert::apply( + VecConvert::apply(src)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + at::vec::Vectorized src = src_n[0]; + __m128i res128 = cvtfp32_fp8e4m3(src); + return at::vec::Vectorized(_mm512_castsi128_si512(res128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + // cvt first 16x8 bits from Float8_e4m3fn to float + at::vec::Vectorized src = src_n[0]; + __m512 result; + cvtfp8e4m3_fp32(_mm512_castsi512_si128(src), result); + return at::vec::Vectorized(result); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + at::vec::Vectorized src = src_n[0]; + __m128i res128 = cvtfp32_fp8e5m2(src); + return at::vec::Vectorized(_mm512_castsi128_si512(res128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + // cvt first 16x8 bits from Float8_e5m2 to float + at::vec::Vectorized src = src_n[0]; + __m512 result; + cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result); + return at::vec::Vectorized(result); + } +}; + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h new file mode 100644 index 0000000000000000000000000000000000000000..6f6b558918cb9e8a227a44536cc81c9a58b724cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h @@ -0,0 +1,545 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
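// --- Editorial note (not part of the vendored headers) ---
// For reference on the Float8 conversions at the end of vec512_convert.h
// above: Float8_e4m3fn packs 1 sign, 4 exponent and 3 mantissa bits (the
// "fn" suffix meaning it encodes finite values and NaN but no infinities),
// while Float8_e5m2 packs 1 sign, 5 exponent and 2 mantissa bits and keeps
// infinities and NaN, like a truncated IEEE half-precision value.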
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + // values needs to be public for compilation with clang + // as vec512.h uses it + __m512d values; + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(double val) { + values = _mm512_set1_pd(val); + } + Vectorized( + double val1, + double val2, + double val3, + double val4, + double val5, + double val6, + double val7, + double val8) { + values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8); + } + operator __m512d() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mask_blend_pd(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask( + _mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __mmask8 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_pd(mask, ptr); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + __mmask8 mask = (1ULL << count) - 1; + _mm512_mask_storeu_pd(reinterpret_cast(ptr), mask, values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto cmp_mask = + _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + bool has_inf_nan() const { + __m512d self_sub = _mm512_sub_pd(values, values); + return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & + 0x7777777777777777) != 0; + } + Vectorized map(double (*const f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + 
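// --- Illustrative sketch (editorial, not part of the vendored header) ---
// has_inf_nan() above relies on x - x being exactly 0.0 for every finite x
// but NaN when x is NaN or +/-Inf, so testing the difference for NaN detects
// both at once. A scalar equivalent (function name is hypothetical):
#include <cmath>
inline bool has_inf_nan_ref(double x) {
  return std::isnan(x - x);
}
// --- end of sketch ---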
for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_pd(-0.f); + return _mm512_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm512_castsi512_pd(zero_vector); + const auto nan_vec = _mm512_set1_pd(NAN); + const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ); + const auto not_nan = + _mm512_mask_set1_epi64(zero_vector, not_nan_mask, 0xFFFFFFFFFFFFFFFF); + const auto nan_mask = + _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_pd(c10::pi); + + const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd8_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshd8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind8_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhd8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand8_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhd8_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2d8_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignd8(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd8_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expd8_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2d8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d8_u10(values)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd8(values, q)); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotd8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d8_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd8_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind8_u10(values)); + } + Vectorized sinh() const { + 
return Vectorized(Sleef_sinhd8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd8_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd8_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized floor() const { + return _mm512_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + return _mm512_xor_pd(_mm512_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterd8(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd8_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad8_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm512_div_pd(_mm512_set1_pd(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powd8_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const 
Vectorized& b) { + return _mm512_mul_pd(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm512_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized max = _mm512_max_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized min = _mm512_min_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm512_min_pd(max, _mm512_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm512_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm512_min_pd(max, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_pd(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmadd_pd(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmsub_pd(a, b, c); 
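// --- Illustrative sketch (editorial, not part of the vendored header) ---
// eq()/ne()/lt()/le()/gt()/ge() above AND the all-ones comparison mask with
// Vectorized<double>(1.0), so each lane of the result is a plain 1.0 or 0.0
// that can feed directly into arithmetic. Scalar reference (function name is
// hypothetical); note that with the _CMP_*_OQ predicates NaN compares false:
inline double eq_ref(double a, double b) {
  return (a == b) ? 1.0 : 0.0;
}
// --- end of sketch ---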
+} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h new file mode 100644 index 0000000000000000000000000000000000000000..184bc4db4aaf3bee9387eb7d188e48aa8d8d56cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h @@ -0,0 +1,868 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + static constexpr __m512i zero_vec{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + __m512 values; + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(float val) { + values = _mm512_set1_ps(val); + } + Vectorized( + float val1, + float val2, + float val3, + float val4, + float val5, + float val6, + float val7, + float val8, + float val9, + float val10, + float val11, + float val12, + float val13, + float val14, + float val15, + float val16) { + values = _mm512_setr_ps( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + Vectorized(const float (&arr)[16]) + : Vectorized( + arr[0], + arr[1], + arr[2], + arr[3], + arr[4], + arr[5], + arr[6], + arr[7], + arr[8], + arr[9], + arr[10], + arr[11], + arr[12], + arr[13], + arr[14], + arr[15]) {} + operator __m512() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mask_blend_ps(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized 
loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __mmask16 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_ps(mask, ptr); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + __mmask16 mask = (1ULL << count) - 1; + _mm512_mask_storeu_ps(reinterpret_cast(ptr), mask, values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + bool has_inf_nan() const { + __m512 self_sub = _mm512_sub_ps(values, values); + return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & + 0x7777777777777777) != 0; + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_ps(-0.f); + return _mm512_andnot_ps(mask, values); + } + Vectorized angle() const { + __m512 zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto not_nan_vec = _mm512_mask_set1_epi32( + _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask( + _mm512_castsi512_ps(not_nan_vec), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf16_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshf16_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf16_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhf16_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf16_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhf16_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2f16_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignf16(values, sign)); + } + Vectorized erf() const { + // constants + const auto neg_zero_vec = _mm512_set1_ps(-0.f); + const auto one_vec = _mm512_set1_ps(1.0f); + const auto p = _mm512_set1_ps(0.3275911f); + const auto p1 = _mm512_set1_ps(0.254829592f); + const auto p2 = _mm512_set1_ps(-0.284496736f); + const auto p3 = _mm512_set1_ps(1.421413741f); + const auto p4 = _mm512_set1_ps(-1.453152027f); + const auto p5 = _mm512_set1_ps(1.061405429f); + // sign(x) + auto sign_mask = _mm512_and_ps(neg_zero_vec, values); + auto abs_vec = _mm512_abs_ps(values); + // t = 
1 / (p * abs(x) + 1) + auto tmp0 = _mm512_fmadd_ps(p, abs_vec, one_vec); + auto t = _mm512_div_ps(one_vec, tmp0); + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = _mm512_fmadd_ps(p5, t, p4); + auto tmp2 = _mm512_fmadd_ps(tmp1, t, p3); + auto tmp3 = _mm512_fmadd_ps(tmp2, t, p2); + auto r = _mm512_fmadd_ps(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = _mm512_mul_ps(values, values); + auto neg_pow_2 = _mm512_xor_ps(neg_zero_vec, pow_2); + // auto tmp4 = exp(neg_pow_2); + auto tmp4 = Vectorized(Sleef_expf16_u10(neg_pow_2)); + auto tmp5 = _mm512_xor_ps(neg_zero_vec, tmp4); + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = _mm512_mul_ps(tmp5, t); + auto tmp7 = _mm512_fmadd_ps(tmp6, r, one_vec); + return _mm512_xor_ps(sign_mask, tmp7); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf16_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf16_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2f16_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f16_u10(values)); + } + Vectorized exp_u20() const { + // A faster version of exp with ULP=20 + const __m512 vec_factorial_1 = + _mm512_set1_ps(0.999999701f); // 1/factorial(1) + const __m512 vec_factorial_2 = + _mm512_set1_ps(0.499991506f); // 1/factorial(2) + const __m512 vec_factorial_3 = + _mm512_set1_ps(0.166676521f); // 1/factorial(3) + const __m512 vec_factorial_4 = + _mm512_set1_ps(0.0418978221f); // 1/factorial(4) + const __m512 vec_factorial_5 = + _mm512_set1_ps(0.00828929059f); // 1/factorial(5) + const __m512 vec_exp_log2ef = + _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + const __m512 vec_half = _mm512_set1_ps(0.5f); + const __m512 vec_one = _mm512_set1_ps(1.f); + const __m512 vec_zero = _mm512_set1_ps(0.f); + const __m512 vec_two = _mm512_set1_ps(2.f); + const __m512 vec_ln2f = + _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + const __m512 vec_ln_flt_min = + _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); + const __m512 vec_ln_flt_max = + _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); + const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; + + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + auto less_ln_flt_min_mask = + _mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); + auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); + vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); + + // fx = floorf(x * log2ef + 0.5) + auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); + auto vec_fx_i = _mm512_cvt_roundps_epi32( + vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); + vec_fx = _mm512_cvtepi32_ps(vec_fx_i); + + // x = x - fx * ln2 + auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); + + // compute polynomial + auto vec_res = + _mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); + + // compute 2^(n-1) + auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); + auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); + auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, 
vec_127); + vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); + auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); + vec_two_pow_n = + _mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero); + + // y = y * 2^n + vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); + vec_res = _mm512_mul_ps(vec_res, vec_two); + return vec_res; + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf16(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf16_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f16_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f16_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pf16_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf16_u35(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhf16_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosf16_u35(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf16_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized floor() const { + return _mm512_floor_ps(values); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotf16_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm512_xor_ps(_mm512_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterf16(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf16_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf16_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammaf16_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm512_div_ps(_mm512_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powf16_u10(values, b)); + } + float reduce_add() const { + return _mm512_reduce_add_ps(values); + } + float reduce_max() const { + return _mm512_reduce_max_ps(values); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mul_ps(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm512_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto max = _mm512_max_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto min = _mm512_min_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
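+  // ORing the min with the all-ones mask forces every lane where either input
+  // was NaN to a NaN bit pattern, so NaNs propagate as IEEE 754-201X `minimum` requires.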
+ return _mm512_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm512_min_ps(max, _mm512_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm512_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm512_max_ps(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_ps(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmadd_ps(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmsub_ps(a, b, c); +} + +// TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen for micro gemm +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 +// kernel for transposing mxn where m, n <= 16 +// (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + N instructions +inline void transpose_block( + at::vec::VectorizedN& input, + int M = 16, + int N = 16) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_block expects M, N <= 16."); + // unpacking and interleaving 32-bit elements + __m512 temp[16]; + int i; + for (i = 0; i < (M + 1) / 2; ++i) { + temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]); + temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]); + } + for (i = i * 2; i < 16; ++i) { + temp[i] = _mm512_setzero_ps(); + } + + // unpacking and interleaving 64-bit elements + for (i = 0; i < (M + 3) / 4; ++i) { + input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd( + _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); + input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd( + _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); + input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd( + _mm512_castps_pd(temp[4 * i + 1]), 
_mm512_castps_pd(temp[4 * i + 3]))); + input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd( + _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); + } + + // shuffle 128-bits (composed of 4 32-bit elements) + for (i = 0; i < (M + 7) / 8; ++i) { + temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88); + temp[8 * i + 1] = + _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88); + temp[8 * i + 2] = + _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88); + temp[8 * i + 3] = + _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88); + temp[8 * i + 4] = + _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd); + temp[8 * i + 5] = + _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd); + temp[8 * i + 6] = + _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd); + temp[8 * i + 7] = + _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd); + } + + for (i = 0; i < N; ++i) { + if (i < 8) { + input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88); + } else { + input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd); + } + } +} + +// TODO(jgong5): rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 +// kernel for transposing mxn where m, n <= 16 +// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions +inline void transpose_mxn_16x16( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst, + int M, + int N) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn expects M, N <= 16."); + // load from src to registers + at::vec::VectorizedN input; + int i; + if (N == 16) { + for (i = 0; i < M; ++i) { + input[i] = _mm512_loadu_ps(&src[i * ld_src]); + } + } else { + __mmask16 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); + } + } + for (; i < 16; ++i) { + // Not really needed but to avoid uninitialized variable warning. + // Shouldn't be much overhead because xor can be executed in parallel with + // other instructions. 
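+  // Zero-fill the rows beyond M so transpose_block below only shuffles
+  // initialized registers; their transposed results are never stored.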
+ input[i] = _mm512_setzero_ps(); + } + + transpose_block(input, M, N); + + // store from registers to dst + if (M == 16) { + for (i = 0; i < N; ++i) { + _mm512_storeu_ps(&dst[i * ld_dst], input[i]); + } + } else { + __mmask16 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); + } + } +} + +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst, + int M, + int N) { + int64_t i = 0; + for (; i < M / 16 * 16; i += 16) { + int64_t j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16); + } + // handle remainder j + int nrem = N - j; + if (nrem > 0) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem); + } + } + // handle remainder i + int mrem = M - i; + if (mrem > 0) { + int j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16); + } + // handle remainder j + int nrem = N - j; + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem); + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t, int> = 0> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h new file mode 100644 index 0000000000000000000000000000000000000000..f247ae4574580e8e87d72bfcc41780856c924d22 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h @@ -0,0 +1,661 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +static inline void cvtfp8e4m3_fp32(const __m128i& a, __m512& o) { + // Zero Extend + __m512i x = _mm512_cvtepu8_epi32(a); + __m512i val = _mm512_and_epi32( + _mm512_slli_epi32(x, 24), _mm512_set1_epi32(0x7FFFFFFF)); // nonsign_val + __m512i mant = + _mm512_and_si512(x, _mm512_set1_epi32(0x07)); // mantissa = x & 0x07 + __m512i exp = _mm512_and_si512( + _mm512_srli_epi32(x, 3), + _mm512_set1_epi32(0x0F)); // exp = (x >> 3) & 0x0F + __m512i sign = + _mm512_and_si512(x, _mm512_set1_epi32(0x80)); // sign = x & 0x80 + __m512i _zeros = _mm512_setzero_si512(); + + // --- Step 1: Calculate the renorm_shift + __m512i renorm_shift = _zeros; + // Denorm case (exp == 0 && mant != 0) --- + __mmask16 denormal_mask = _mm512_cmpeq_epi32_mask(exp, _zeros) & + _mm512_cmpneq_epi32_mask(mant, _zeros); + if (denormal_mask) { + // An alternative solution is as what scalar did in + // pytorch/c10/util/Float8_e4m3fn.h To count the num of leading zeros, since + // here we know the unsigned denorm value has zero sign and exp which is 5 + // leading zeros, we need to count the leading zero of mant (3bit) which may + // done through table lookup for example: const uint8_t lz_table[8] = {3, 2, + // 1, 1, 0, 0, 0, 0}; num_leading_zero = lz_table[mant] + 5; + + __m512i _ones = _mm512_set1_epi32(1); + __m512i _twos = _mm512_set1_epi32(2); + __m512i _threes = _mm512_set1_epi32(3); + + // Default leading zero number for denorm value is 1 = 5 - 4 + __m512i denorm_renorm_shift = _ones; + // For mant 001, leading zero number is 3 = 7 -4 + __mmask16 leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _ones); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _threes); + // For mant 010 and 011, leading zero number is 2 = 6 -4 + leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _twos); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _twos); + leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _threes); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _twos); + + renorm_shift = + _mm512_mask_mov_epi32(renorm_shift, denormal_mask, denorm_renorm_shift); + } + + // --- Step 2: calculate norm and denorm --- + __m512i norm_shifted = + _mm512_srli_epi32(_mm512_sllv_epi32(val, renorm_shift), 4); + // exponent bias adjustment: (0x78 - renorm_shift) << 23 + __m512i exp_bias = _mm512_slli_epi32( + _mm512_sub_epi32(_mm512_set1_epi32(0x78), renorm_shift), 23); + val = _mm512_add_epi32(norm_shifted, exp_bias); + + // --- Step 3: Nan case (exp == 0xF && mant == 0x07) --- + __mmask16 nan_mask = _mm512_cmpeq_epi32_mask(exp, _mm512_set1_epi32(0xF)) & + _mm512_cmpeq_epi32_mask(mant, _mm512_set1_epi32(0x07)); + if (nan_mask) { + const __m512i nan_values = _mm512_set1_epi32(0x7FC00000); + val = _mm512_mask_mov_epi32(val, nan_mask, nan_values); + } + + // --- Step 4: Zero case (exp == 0x00 && mant == 0x00) --- + __mmask16 zero_mask = _mm512_cmpeq_epi32_mask(exp, _zeros) & + _mm512_cmpeq_epi32_mask(mant, _zeros); + if (zero_mask) { + val = _mm512_mask_mov_epi32(val, zero_mask, _zeros); + } + + // --- Step 5: OR with sign (sign bit << 24 to get to bit 31) --- + val = _mm512_or_si512(val, _mm512_slli_epi32(sign, 24)); + + o = 
_mm512_castsi512_ps(val); +} + +static inline __m128i cvtfp32_fp8e4m3(const __m512& src) { + // cvt 16x32 from fp32 to fp8 e4m3 + const __m512i sign_mask = _mm512_set1_epi32(0x80000000); + const __m512i fp8_max = _mm512_set1_epi32(UINT32_C(1087) << 20); + const __m512i denorm_thresh = _mm512_set1_epi32(UINT32_C(121) << 23); + const __m512i denorm_mask = _mm512_set1_epi32(UINT32_C(141) << 23); + const __m512i bias_part1 = _mm512_set1_epi32((uint32_t)(7 - 127) << 23); + const __m512i rounding_bias = _mm512_set1_epi32(0x7FFFF); + __m512i f_bits = _mm512_castps_si512(src); + // Extract and save sign + __m512i sign = _mm512_and_epi32(f_bits, sign_mask); + f_bits = _mm512_xor_epi32(f_bits, sign); + + // Prepare result containers + __m512i result = _mm512_setzero_si512(); + + // Step 1: Handle case of overflow + // (f_bits >= fp8_max): set result = 0x7f + __mmask16 overflow_mask = _mm512_cmpge_epu32_mask(f_bits, fp8_max); + if (overflow_mask) { + result = _mm512_mask_set1_epi32(result, overflow_mask, 0x7f); + } + + // Step 2: Handle small numbers (denormals) + // Small numbers (f_bits < denorm_thresh) + __mmask16 denorm_thresh_mask = _mm512_cmplt_epu32_mask(f_bits, denorm_thresh); + + if (denorm_thresh_mask) { + __m512 small_input = _mm512_castsi512_ps(f_bits); + __m512 small_denorm = + _mm512_add_ps(small_input, _mm512_castsi512_ps(denorm_mask)); + __m512i small_denorm_bits = _mm512_castps_si512(small_denorm); + __m512i small_result = _mm512_sub_epi32(small_denorm_bits, denorm_mask); + result = _mm512_mask_mov_epi32(result, denorm_thresh_mask, small_result); + } + + // Step 3: Handle normal numbers + __mmask16 normal_mask = ~(overflow_mask | denorm_thresh_mask); + + if (normal_mask) { + // mant_odd = (f_bits >> 20) & 1 + __m512i mant_odd = + _mm512_and_epi32(_mm512_srli_epi32(f_bits, 20), _mm512_set1_epi32(1)); + // f_bits += bias_part1 + rounding_bias + __m512i rounded = _mm512_add_epi32(f_bits, bias_part1); + rounded = _mm512_add_epi32(rounded, rounding_bias); + // Add mant_odd + rounded = _mm512_add_epi32(rounded, mant_odd); + // Shift right by 20 bits + __m512i normal_result = _mm512_srli_epi32(rounded, 20); + result = _mm512_mask_mov_epi32(result, normal_mask, normal_result); + } + + // Merge back the sign + __m512i sign_shifted = _mm512_srli_epi32(sign, 24); + result = _mm512_or_epi32(result, sign_shifted); + + // Now result is 16 x 32-bit integers, but we only need 8-bit for each + __m512i packed = _mm512_and_si512(result, _mm512_set1_epi32(0xFF)); + + // Narrow 32-bit integers to 8-bit + return _mm512_cvtepi32_epi8(packed); +} + +static inline float fp8e4m3_to_fp32_scalar(uint8_t val) { + __m512i v = _mm512_set1_epi8(val); + __m128i v_128 = _mm512_castsi512_si128(v); + __m512 o; + cvtfp8e4m3_fp32(v_128, o); + return _mm512_cvtss_f32(o); +} + +static inline uint8_t fp32_to_fp8e4m3_scalar(float val) { + __m512 v = _mm512_set1_ps(val); + __m128i o = cvtfp32_fp8e4m3(v); + return static_cast(_mm_cvtsi128_si32(o)); +} + +static inline void cvtfp8e5m2_fp32(const __m128i& a, __m512& o) { + __m256i a_256 = _mm256_castsi128_si256(a); + __m512i a_512 = _mm512_cvtepu8_epi16(a_256); + a_512 = _mm512_slli_epi16(a_512, 8); + a_256 = _mm512_castsi512_si256(a_512); + cvtfp16_fp32(a_256, o); +} + +static inline __m128i cvtfp32_fp8e5m2(const __m512& src) { + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + // Cvt to bits + __m512i input_bits = _mm512_castps_si512(src); + __m512i result = 
_mm512_setzero_si512(); + + // Get the sign + __m512i sign = _mm512_and_si512(input_bits, _mm512_set1_epi32(0x80000000)); + + // Get the unsigned input + input_bits = _mm512_xor_si512(input_bits, sign); + + // Calculate the mask for inf, nan and denorm + __mmask16 greater_than_fp8_max = + _mm512_cmpge_epi32_mask(input_bits, _mm512_set1_epi32(fp8_max)); + __mmask16 greater_than_fp32_inf = + _mm512_cmpgt_epi32_mask(input_bits, _mm512_set1_epi32(fp32_inf)); + __mmask16 less_than_normal = _mm512_cmpgt_epi32_mask( + _mm512_set1_epi32((UINT32_C(113) << 23)), input_bits); + __m512i temp_bits_for_denorm = _mm512_setzero_si512(); + if (less_than_normal) { + __m512i denorm_mask_512i = _mm512_set1_epi32(denorm_mask); + temp_bits_for_denorm = _mm512_castps_si512(_mm512_add_ps( + _mm512_castsi512_ps(input_bits), + _mm512_castsi512_ps(denorm_mask_512i))); + temp_bits_for_denorm = + _mm512_sub_epi32(temp_bits_for_denorm, denorm_mask_512i); + } + + // Step 1: Norm Val + __m512i mant_odd_mask = + _mm512_and_epi32(_mm512_srli_epi32(input_bits, 21), _mm512_set1_epi32(1)); + input_bits = _mm512_add_epi32( + input_bits, _mm512_set1_epi32(((uint32_t)(15 - 127) << 23) + 0xFFFFF)); + input_bits = _mm512_add_epi32(input_bits, mant_odd_mask); + result = _mm512_srli_epi32(input_bits, 21); + + // Step 2: INF and NAN + if (greater_than_fp8_max) { + result = _mm512_mask_mov_epi32( + result, greater_than_fp8_max, _mm512_set1_epi8(0x7C)); + if (greater_than_fp32_inf) { + result = _mm512_mask_mov_epi32( + result, greater_than_fp32_inf, _mm512_set1_epi8(0x7F)); + } + } + + // Step 3: Denorm val + if (less_than_normal) { + result = + _mm512_mask_mov_epi32(result, less_than_normal, temp_bits_for_denorm); + } + + // Step 4: restore sign + result = _mm512_or_si512(result, _mm512_srli_epi32(sign, 24)); + + return _mm512_cvtepi32_epi8(result); +} + +static inline float fp8e5m2_to_fp32_scalar(uint8_t val) { + __m512i v = _mm512_set1_epi8(val); + __m128i v_128 = _mm512_castsi512_si128(v); + __m512 o; + cvtfp8e5m2_fp32(v_128, o); + return _mm512_cvtss_f32(o); +} + +static inline uint8_t fp32_to_fp8e5m2_scalar(float val) { + __m512 v = _mm512_set1_ps(val); + __m128i o = cvtfp32_fp8e5m2(v); + return static_cast(_mm_cvtsi128_si32(o)); +} + +template +class Vectorizedf8 { + static_assert( + std::integral_constant < bool, + std::is_same_v || std::is_same_v < T, + at::Float8_e5m2 >> ::value, + "Support only float8 e4m3."); + + private: + __m512i values; + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m512 a0, a1, a2, a3; + __m512 b0, b1, b2, b3; + __m512 o0, o1, o2, o3; + if constexpr (std::is_same_v) { + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 0), a0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 1), a1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 2), a2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 3), a3); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3); + } else { + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 0), a0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 1), a1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 2), a2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2); + 
cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 3), a3); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3); + } + + o0 = op(a0, b0); + o1 = op(a1, b1); + o2 = op(a2, b2); + o3 = op(a3, b3); + __m128i o128_0, o128_1, o128_2, o128_3; + if constexpr (std::is_same_v) { + o128_0 = cvtfp32_fp8e4m3(o0); + o128_1 = cvtfp32_fp8e4m3(o1); + o128_2 = cvtfp32_fp8e4m3(o2); + o128_3 = cvtfp32_fp8e4m3(o3); + } else { + o128_0 = cvtfp32_fp8e5m2(o0); + o128_1 = cvtfp32_fp8e5m2(o1); + o128_2 = cvtfp32_fp8e5m2(o2); + o128_3 = cvtfp32_fp8e5m2(o3); + } + + __m512i result = _mm512_setzero_si512(); + result = _mm512_inserti32x4(result, o128_0, 0); + result = _mm512_inserti32x4(result, o128_1, 1); + result = _mm512_inserti32x4(result, o128_2, 2); + result = _mm512_inserti32x4(result, o128_3, 3); + + return result; + } + + public: + using value_type = uint8_t; + using size_type = int; + static constexpr size_type size() { + return 64; + } + Vectorizedf8() {} + Vectorizedf8(__m512i v) : values(v) {} + Vectorizedf8(T val) { + value_type uw = val.x; + values = _mm512_set1_epi8(uw); + } + operator __m512i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else if (count == 16) { + // Fast path if only load element number of 16 + __m128i input_128 = + _mm_loadu_si128(reinterpret_cast(ptr)); + return _mm512_castsi128_si512(input_128); + } else { + __mmask64 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi8(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + if (count == 16) { + // Fast path if only store element number of 16 + _mm_storeu_si128( + reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values)); + } else { + __mmask64 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi8(ptr, mask, values); + } + } + } + + Vectorized abs() const { + return _mm512_andnot_si512(_mm512_set1_epi8(0x80), values); + } + + Vectorized inline operator==(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator!=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator>(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator>=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator<(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto 
zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator<=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } +}; + +template <> +class Vectorized : public Vectorizedf8 { + public: + using Vectorizedf8::Vectorizedf8; + + using value_type = Float8_e4m3fn; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template < + typename T, + typename Op, + std::enable_if_t< + std::is_same_v || + std::is_same_v, + int> = 0> +static inline Vectorized binary_fp8_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m512 a0, a1, a2, a3; + __m512 b0, b1, b2, b3; + __m512 o0, o1, o2, o3; + if constexpr (std::is_same_v) { + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 0), a0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 0), b0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 1), a1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 1), b1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 2), a2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 2), b2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 3), a3); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 3), b3); + } else { + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 0), a0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 0), b0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 1), a1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 1), b1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 2), a2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 2), b2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 3), a3); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 3), b3); + } + o0 = op(a0, b0); + o1 = op(a1, b1); + o2 = op(a2, b2); + o3 = op(a3, b3); + + __m128i o128_0, o128_1, o128_2, o128_3; + if constexpr (std::is_same_v) { + o128_0 = cvtfp32_fp8e4m3(o0); + o128_1 = cvtfp32_fp8e4m3(o1); + o128_2 = cvtfp32_fp8e4m3(o2); + o128_3 = cvtfp32_fp8e4m3(o3); + } else { + o128_0 = cvtfp32_fp8e5m2(o0); + o128_1 = cvtfp32_fp8e5m2(o1); + o128_2 = cvtfp32_fp8e5m2(o2); + o128_3 = cvtfp32_fp8e5m2(o3); + } + + __m512i result = _mm512_setzero_si512(); + result = _mm512_inserti32x4(result, o128_0, 0); + result = _mm512_inserti32x4(result, o128_1, 1); + result = _mm512_inserti32x4(result, o128_2, 2); + result = _mm512_inserti32x4(result, o128_3, 3); + + return result; +} + +// Refer to +// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, +// -, *, /, planed to be deleted in the future and here is just to make compiler +// happy +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} + +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} + +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, 
[](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} + +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +class Vectorized : public Vectorizedf8 { + public: + using Vectorizedf8::Vectorizedf8; + + using value_type = Float8_e5m2; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +// Refer to +// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, +// -, *, /, planed to be deleted in the future and here is just to make compiler +// happy +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} + +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} + +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} + +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h 
b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h new file mode 100644 index 0000000000000000000000000000000000000000..5a4328b1be536af41481f61714408ae304b87371 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h @@ -0,0 +1,2115 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX512 + +struct Vectorizedi { + protected: + __m512i values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static inline __m512i invert(const __m512i& v) { + const auto ones = _mm512_set1_epi64(-1); + return _mm512_xor_si512(ones, v); + } + + public: + Vectorizedi() {} + Vectorizedi(__m512i v) : values(v) {} + operator __m512i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX512 + +#ifdef CPU_CAPABILITY_AVX512 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int64_t v) { + values = _mm512_set1_epi64(v); + } + Vectorized( + int64_t val1, + int64_t val2, + int64_t val3, + int64_t val4, + int64_t val5, + int64_t val6, + int64_t val7, + int64_t val8) { + values = _mm512_setr_epi64(val1, val2, val3, val4, val5, val6, val7, val8); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi64(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mask_ = _mm512_cmp_epi64_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi64(mask_, a.values, b.values); + } + template + static Vectorized arange( + int64_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask8 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi64(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask8 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi64(ptr, mask, values); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto is_larger_mask = _mm512_cmpgt_epi64_mask(zero_vector, values); + auto is_larger = + _mm512_mask_set1_epi64(zero_vector, is_larger_mask, 0xFFFFFFFFFFFFFFFF); + auto inverse = _mm512_xor_si512(values, is_larger); + return _mm512_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi64(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +class Vectorized : public Vectorizedi { + private: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; + + public: + using value_type = int32_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { + values = _mm512_set1_epi32(v); + } + Vectorized( + int32_t val1, + int32_t val2, + int32_t val3, + int32_t val4, + int32_t val5, + int32_t val6, + int32_t val7, + int32_t val8, + int32_t val9, + int32_t val10, + int32_t val11, + int32_t val12, + int32_t val13, + int32_t val14, + int32_t val15, + int32_t val16) { + values = _mm512_setr_epi32( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi32(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& 
a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi32(0xFFFFFFFF); + auto mask_ = _mm512_cmp_epi32_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi32(mask_, a.values, b.values); + } + template + static Vectorized arange( + int32_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask16 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi32(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask16 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi32(ptr, mask, values); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + int32_t reduce_add() const { + return _mm512_reduce_add_epi32(values); + } + int32_t reduce_max() const { + return _mm512_reduce_max_epi32(values); + } + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_vec = + _mm512_loadu_si512(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_ps(input_vec); + _mm512_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, double* dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_256_vec = + _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_pd(input_256_vec); + _mm512_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public 
Vectorizedi { + private: + static const Vectorized ones; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = int16_t; + static constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { + values = _mm512_set1_epi16(v); + } + Vectorized( + int16_t val1, + int16_t val2, + int16_t val3, + int16_t val4, + int16_t val5, + int16_t val6, + int16_t val7, + int16_t val8, + int16_t val9, + int16_t val10, + int16_t val11, + int16_t val12, + int16_t val13, + int16_t val14, + int16_t val15, + int16_t val16, + int16_t val17, + int16_t val18, + int16_t val19, + int16_t val20, + int16_t val21, + int16_t val22, + int16_t val23, + int16_t val24, + int16_t val25, + int16_t val26, + int16_t val27, + int16_t val28, + int16_t val29, + int16_t val30, + int16_t val31, + int16_t val32) { + values = _mm512_set_epi16( + val32, + val31, + val30, + val29, + val28, + val27, + val26, + val25, + val24, + val23, + val22, + val21, + val20, + val19, + val18, + val17, + val16, + val15, + val14, + val13, + val12, + val11, + val10, + val9, + val8, + val7, + val6, + val5, + val4, + val3, + val2, + val1); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return 
blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask32 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi16(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi16(ptr, mask, values); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + + protected: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; + + public: + using value_type = T; + static constexpr int size() { + return 64; + } + using Vectorizedi::Vectorizedi; + Vectorized8() {} + Vectorized8(T v) { + values = _mm512_set1_epi8(v); + } + Vectorized8( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, 
+ T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32, + T val33, + T val34, + T val35, + T val36, + T val37, + T val38, + T val39, + T val40, + T val41, + T val42, + T val43, + T val44, + T val45, + T val46, + T val47, + T val48, + T val49, + T val50, + T val51, + T val52, + T val53, + T val54, + T val55, + T val56, + T val57, + T val58, + T val59, + T val60, + T val61, + T val62, + T val63, + T val64) { + values = _mm512_set_epi8( + val64, + val63, + val62, + val61, + val60, + val59, + val58, + val57, + val56, + val55, + val54, + val53, + val52, + val51, + val50, + val49, + val48, + val47, + val46, + val45, + val44, + val43, + val42, + val41, + val40, + val39, + val38, + val37, + val36, + val35, + val34, + val33, + val32, + val31, + val30, + val29, + val28, + val27, + val26, + val25, + val24, + val23, + val22, + val21, + val20, + val19, + val18, + val17, + val16, + val15, + val14, + val13, + val12, + val11, + val10, + val9, + val8, + val7, + val6, + val5, + val4, + val3, + val2, + val1); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi8(mask, a.values, b.values); + } + template + static Vectorized arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step, + base + 32 * step, + base + 33 * step, + base + 34 * step, + base + 35 * step, + base + 36 * step, + base + 37 * step, + base + 38 * step, + base + 39 * step, + base + 40 * step, + base + 41 * step, + base + 42 * step, + base + 43 * step, + base + 44 * step, + base + 45 * step, + base + 46 * step, + base + 47 * step, + base + 48 * step, + base + 49 * step, + base + 50 * step, + base + 51 * step, + base + 52 * step, + base + 53 * step, + base + 54 * step, + base + 55 * step, + base + 56 * step, + base + 57 * step, + base + 58 * step, + base + 59 * step, + base + 60 * step, + base + 61 * step, + base + 62 * step, + base + 63 * step); + } + static Vectorized set(Vectorized a, Vectorized b, T count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return 
blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + case 32: + return blend<0xFFFFFFFF>(a, b); + case 33: + return blend<0x1FFFFFFFF>(a, b); + case 34: + return blend<0x3FFFFFFFF>(a, b); + case 35: + return blend<0x7FFFFFFFF>(a, b); + case 36: + return blend<0xFFFFFFFFF>(a, b); + case 37: + return blend<0x1FFFFFFFFF>(a, b); + case 38: + return blend<0x3FFFFFFFFF>(a, b); + case 39: + return blend<0x7FFFFFFFFF>(a, b); + case 40: + return blend<0xFFFFFFFFFF>(a, b); + case 41: + return blend<0x1FFFFFFFFFF>(a, b); + case 42: + return blend<0x3FFFFFFFFFF>(a, b); + case 43: + return blend<0x7FFFFFFFFFF>(a, b); + case 44: + return blend<0xFFFFFFFFFFF>(a, b); + case 45: + return blend<0x1FFFFFFFFFFF>(a, b); + case 46: + return blend<0x3FFFFFFFFFFF>(a, b); + case 47: + return blend<0x7FFFFFFFFFFF>(a, b); + case 48: + return blend<0xFFFFFFFFFFFF>(a, b); + case 49: + return blend<0x1FFFFFFFFFFFF>(a, b); + case 50: + return blend<0x3FFFFFFFFFFFF>(a, b); + case 51: + return blend<0x7FFFFFFFFFFFF>(a, b); + case 52: + return blend<0xFFFFFFFFFFFFF>(a, b); + case 53: + return blend<0x1FFFFFFFFFFFFF>(a, b); + case 54: + return blend<0x3FFFFFFFFFFFFF>(a, b); + case 55: + return blend<0x7FFFFFFFFFFFFF>(a, b); + case 56: + return blend<0xFFFFFFFFFFFFFF>(a, b); + case 57: + return blend<0x1FFFFFFFFFFFFFF>(a, b); + case 58: + return blend<0x3FFFFFFFFFFFFFF>(a, b); + case 59: + return blend<0x7FFFFFFFFFFFFFF>(a, b); + case 60: + return blend<0xFFFFFFFFFFFFFFF>(a, b); + case 61: + return blend<0x1FFFFFFFFFFFFFFF>(a, b); + case 62: + return blend<0x3FFFFFFFFFFFFFFF>(a, b); + case 63: + return blend<0x7FFFFFFFFFFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu_one_fourth(const void* ptr) { + // Fast path if only load element number of 16. + // Note: We didn't merge it as fast path of loadu(const void* ptr, T count), + // Because loadu(const void* ptr, T count) requires zero initialization for + // upper 384 bits. However, by using _mm512_castsi128_si512, the upper 384 + // bits of the result are undefined. + // TODO We can use _mm512_zextsi128_si512 in the furture, + // since gcc 9.3 doesn't support it now. + __m128i input_128 = _mm_loadu_si128(reinterpret_cast(ptr)); + return _mm512_castsi128_si512(input_128); + } + static Vectorized loadu(const void* ptr, T count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else if (count == 16) { + // Fast path if only load element number of 16 + return loadu_one_fourth(ptr); + } else { + __mmask64 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi8(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + if (count == 16) { + // Fast path if only store element number of 16 + _mm_storeu_si128( + reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values)); + } else { + __mmask64 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi8(ptr, mask, values); + } + } + } + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + + Vectorized neg() const; + + Vectorized abs() const { + return _mm512_abs_epi8(values); + } + + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epu8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = 
_mm512_cmplt_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi64(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi16(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +// Negation. 
Defined here so we can utilize operator- +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi64(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_512( + const Vectorized& a, + const Vectorized& b, + Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying int8_t +#ifndef CPU_CAPABILITY_AVX512 + return int_elementwise_binary_512(a, b, std::multiplies()); +#else + __m512i mask00FF = _mm512_set1_epi16(0x00FF); + __m512i a_lo = _mm512_srai_epi16(_mm512_slli_epi16(a, 8), 8); + __m512i b_lo = _mm512_srai_epi16(_mm512_slli_epi16(b, 8), 8); + __m512i a_hi = _mm512_srai_epi16(a, 8); + __m512i b_hi = _mm512_srai_epi16(b, 8); + __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF); + __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8); + __m512i res = _mm512_or_si512(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t +#ifndef CPU_CAPABILITY_AVX512 + return int_elementwise_binary_512(a, b, std::multiplies()); +#else + __m512i mask00FF = _mm512_set1_epi16(0x00FF); + __m512i a_lo = _mm512_and_si512(a, mask00FF); + __m512i b_lo = _mm512_and_si512(b, mask00FF); + __m512i a_hi = _mm512_srli_epi16(a, 8); + __m512i b_hi = _mm512_srli_epi16(b, 8); + __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF); + __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8); + __m512i res = _mm512_or_si512(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi64(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi32(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi16(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi8(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epu8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi64(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi32(a, b); +} + +template <> 
+Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi16(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epu8(a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi64(max_val, _mm512_max_epi64(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi32(max_val, _mm512_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi16(max_val, _mm512_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi8(max_val, _mm512_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epu8(max_val, _mm512_max_epu8(a, min_val)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi64(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi8(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epu8(max_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi64(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi8(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epu8(min_val, a); +} + +template +std::enable_if_t< + !(std::is_same_v || std::is_same_v), + Vectorized< + int32_t>> inline convert_to_int32(const T* ptr, int count = Vectorized::size()) { + return Vectorized::loadu(ptr, count); +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const int8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm512_cvtepi8_epi32( + _mm_loadu_si128(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm512_cvtepi8_epi32(_mm512_castsi512_si128(a)); + } +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const uint8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast(ptr))); + } else { + 
auto a = Vectorized::loadu(ptr, count); + return _mm512_cvtepu8_epi32(_mm512_castsi512_si128(a)); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} + +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm512_xor_si512(a, _mm512_set1_epi32(-1)); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} 
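+// The operator overloads above (==, <, >, ...) return lane-wide masks whose
+// set lanes are all ones (e.g. 0xFFFFFFFF per 32-bit lane), which is the form
+// blendv() consumes. The eq/ne/gt/ge/lt/le methods defined in this block
+// instead AND that mask with a vector of ones, so each lane holds 0 or 1,
+// matching the boolean result of the corresponding scalar comparison.
+//
+// A minimal illustrative sketch built on those 0/1 results; count_gt_int32_sketch
+// is a hypothetical helper, not part of this header:
+inline int64_t count_gt_int32_sketch(const int32_t* a, const int32_t* b, int64_t n) {
+  int64_t total = 0;
+  int64_t i = 0;
+  constexpr int kVLen = Vectorized<int32_t>::size();
+  for (; i + kVLen <= n; i += kVLen) {
+    auto va = Vectorized<int32_t>::loadu(a + i);
+    auto vb = Vectorized<int32_t>::loadu(b + i);
+    // gt() yields 1 in lanes where a > b and 0 elsewhere, so the horizontal
+    // sum of the result is the number of matching lanes in this vector.
+    total += va.gt(vb).reduce_add();
+  }
+  for (; i < n; ++i) {
+    total += (a[i] > b[i]) ? 1 : 0;
+  }
+  return total;
+}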
+ +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template < + bool left_shift, + typename T, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + int> = 0> +Vectorized inline shift_512_8( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 512 bits as an + // array of 8-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m512i ctl_0_1 = _mm512_set_epi8( + 62, + 0x80, + 60, + 0x80, + 58, + 0x80, + 56, + 0x80, + 54, + 0x80, + 52, + 0x80, + 50, + 0x80, + 48, + 0x80, + 46, + 0x80, + 44, + 0x80, + 42, + 0x80, + 40, + 0x80, + 38, + 0x80, + 36, + 0x80, + 34, + 0x80, + 32, + 0x80, + 30, + 0x80, + 28, + 0x80, + 26, + 0x80, + 24, + 0x80, + 22, + 0x80, + 20, + 0x80, + 18, + 0x80, + 16, + 0x80, + 14, + 0x80, + 12, + 0x80, + 10, + 0x80, + 8, + 0x80, + 6, + 0x80, + 4, + 0x80, + 2, + 0x80, + 0, + 0x80); + __m512i ctl_1_0 = _mm512_set_epi8( + 0x80, + 63, + 0x80, + 61, + 0x80, + 59, + 0x80, + 57, + 0x80, + 55, + 0x80, + 53, + 0x80, + 51, + 0x80, + 49, + 0x80, + 47, + 0x80, + 45, + 0x80, + 43, + 0x80, + 41, + 0x80, + 39, + 0x80, + 37, + 0x80, + 35, + 0x80, + 33, + 0x80, + 31, + 0x80, + 29, + 0x80, + 27, + 0x80, + 25, + 0x80, + 23, + 0x80, + 21, + 0x80, + 19, + 0x80, + 17, + 0x80, + 15, + 0x80, + 13, + 0x80, + 11, + 0x80, + 9, + 0x80, + 7, + 0x80, + 5, + 0x80, + 3, + 0x80, + 1); + + // Masks for bitwise and operation, treating 512 bits as an array of + // 8-bit elements, and considering them in pairs of neighboring + // elements. 
A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m512i keep_0 = _mm512_set1_epi16(0xFF); + __m512i keep_1 = _mm512_set1_epi16(0xFF00); + + // Take each 8-bit element with idx%2==0 from input array to be + // shifted and extend it to 16 bits so that 0s are added to the + // right. Then, perform shifting on this 16-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 16 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1); + __m512i b0 = _mm512_and_si512(b, keep_0); + __m512i c0; + if (left_shift) + c0 = _mm512_sllv_epi16(a0, b0); + else if constexpr (std::is_same_v) + c0 = _mm512_srav_epi16(a0, b0); + else + c0 = _mm512_srlv_epi16(a0, b0); + c0 = _mm512_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m512i a1 = _mm512_and_si512(a, keep_1); + __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0); + __m512i c1; + if (left_shift) + c1 = _mm512_sllv_epi16(a1, b1); + else if constexpr (std::is_same_v) + c1 = _mm512_srav_epi16(a1, b1); + else + c1 = _mm512_srlv_epi16(a1, b1); + c1 = _mm512_and_si512(c1, keep_1); + + // Merge partial results into the final result. 
+ __m512i c = _mm512_or_si512(c0, c1); + + return c; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi16(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi64(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi16(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..2ce1d895329fd7922f8b93e3b9cd164b93917f94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h @@ -0,0 +1,390 @@ +#pragma once + +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = 
Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + dst_n, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n && dst_n >= 1) && + (std::is_same_v || std::is_same_v)>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; + VectorizedN tmp_vec; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int_mask = VecMask(tmp_vec).template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + 1, + mask_t, + 1, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto int_mask = vec_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + auto zero = _mm_set1_epi8(0); + auto temp = _mm_mask_loadu_epi8(zero, mmask, ptr); + return Vectorized( + _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0)); + } +}; + +template +struct VecMaskLoad< + data_t, + 2, + mask_t, + 1, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + at::vec::Vectorized zero_vec(0); + auto int_mask = vec_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + at::vec::VectorizedN result; + if constexpr (std::is_same_v) { + result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } else { + result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } + return result; + } 
+}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castps_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castpd_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +inline bool VecMask::all_zero() const { + __mmask16 mask = _mm512_test_epi32_mask(mask_[0], mask_[0]); + return mask == 0; +} + +template <> +inline bool VecMask::is_masked(int i) const { + return _mm512_movepi32_mask(mask_[0]) & (1 << i); +} + +template <> +inline bool VecMask::all_masked() const { + __mmask16 mask = _mm512_movepi32_mask(mask_[0]); + return mask == 0xffff; +} + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = + all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 8) { + return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && 
(_mm512_movepi64_mask(vec_mask[i]) == 0xff); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + +#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ + T, N, return_type, method, args_def, args) \ + template <> \ + inline return_type VecMask::method args_def const { \ + return cast().method args; \ + } + +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) + +#undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..427d2ca0e600fb5156571cc18d514b59d52be910 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h @@ -0,0 +1,1528 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
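+//
+// A minimal usage sketch of that dequantize/compute/quantize pattern
+// (illustrative only; the function name and buffers below are hypothetical,
+// and it assumes the Vectorized<c10::qint32> specialization defined later in
+// this file plus the Vectorized<float> arithmetic operators defined in the
+// float vector header):
+//
+//   void scale_qint32_sketch(const c10::qint32* src, c10::qint32* dst,
+//                            int64_t n, float scale, int32_t zero_point) {
+//     using QVec = at::vec::Vectorized<c10::qint32>;
+//     auto scale_v = at::vec::Vectorized<float>(scale);
+//     auto zp_v = at::vec::Vectorized<float>(static_cast<float>(zero_point));
+//     for (int64_t i = 0; i + QVec::size() <= n; i += QVec::size()) {
+//       auto q = QVec::loadu(src + i);
+//       auto fp = q.dequantize(scale_v, zp_v);      // float_num_vecs() vectors
+//       for (auto& v : fp) {
+//         v = v * at::vec::Vectorized<float>(2.0f); // arithmetic in float
+//       }
+//       QVec::quantize(fp, scale, zero_point, 1.0f / scale).store(dst + i);
+//     }
+//   }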
+ +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +#ifdef _MSC_VER +__declspec(align(64)) struct Vectorizedqi { + protected: + __m512i vals; +#else +struct Vectorizedqi { + protected: + __m512i vals __attribute__((aligned(64))); +#endif + + public: + Vectorizedqi() {} + Vectorizedqi(__m512i v) : vals(v) {} + operator __m512i() const { + return vals; + } +}; + +template +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + T min_val, + T max_val); + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first [[maybe_unused]], + __m512i second [[maybe_unused]], + int32_t min_val [[maybe_unused]], + int32_t max_val [[maybe_unused]]) { + // This function is for linkage only, will not be used + TORCH_CHECK(false, "pack_saturate_and_clamp is not supported"); + return __m512i{}; +} + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + int8_t min_val, + int8_t max_val) { + __m512i packed_and_sat = _mm512_packs_epi16(first, second); + return _mm512_max_epi8( + _mm512_set1_epi8(min_val), + _mm512_min_epi8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + uint8_t min_val, + uint8_t max_val) { + __m512i packed_and_sat = _mm512_packus_epi16(first, second); + return _mm512_max_epu8( + _mm512_set1_epi8(min_val), + _mm512_min_epu8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(at::vec::Vectorized src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 16*8 bits + __m128i input_128 = _mm512_castsi512_si128(src); + // Convert from 16*uint8/int8 to 16*int32 + __m512i input_512_extended; + if constexpr (std::is_same_v) + input_512_extended = _mm512_cvtepu8_epi32(input_128); + else + input_512_extended = _mm512_cvtepi8_epi32(input_128); + // Convert from 16*int32 to 16*float32 + return _mm512_cvtepi32_ps(input_512_extended); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(at::vec::Vectorized src) { + // Convert from float32 to int32 with truncation + __m512i x_values_int32 = _mm512_cvttps_epi32(src); + + // Convert from int32 to int16 using signed saturation + __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32); + + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to uint8/int8 using unsigned saturation + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, xy_packed_v, min_val, max_val); + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); +} + +template +__FORCE_INLINE void QuantizeAvx512( + const float* src, + T* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 16; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m512i min_v = _mm512_set1_epi32(min_val); + const __m512i max_v = _mm512_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + 
constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m512 inverse_scale_v = _mm512_set1_ps(inverse_scale); + // clang-format off + static const __m512i shuffle_mask_v = _mm512_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + __m512i permute_mask_l8_v = _mm512_set_epi32( + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x0c, + 0x08, + 0x04, + 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm512_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // y + __m512 y_vals = _mm512_load_ps(src + i + VLEN); + __m512 y_transformed_v = _mm512_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm512_min_ps(y_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // z + __m512 z_vals = _mm512_load_ps(src + i + 2 * VLEN); + __m512 z_transformed_v = _mm512_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm512_min_ps(z_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // w + __m512 w_vals = _mm512_load_ps(src + i + 3 * VLEN); + __m512 w_transformed_v = _mm512_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm512_min_ps(w_transformed_v, _mm512_set1_ps(int32_float_max_val)); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_transformed_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_transformed_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_transformed_v); + + // add zero point + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + y_rounded_v = _mm512_add_epi32(y_rounded_v, _mm512_set1_epi32(zero_point)); + z_rounded_v = _mm512_add_epi32(z_rounded_v, _mm512_set1_epi32(zero_point)); + w_rounded_v = _mm512_add_epi32(w_rounded_v, _mm512_set1_epi32(zero_point)); + + __m512i xy_packed_v = _mm512_packs_epi32(x_rounded_v, y_rounded_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_rounded_v, w_rounded_v); + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX512 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + __m512i x_rounded_v = 
_mm512_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + __m512i x_clipped_v = + _mm512_max_epi32(min_v, _mm512_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm512_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm512_permutexvar_epi32(permute_mask_l8_v, x_clipped_v); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(dst + i), + _mm512_castsi512_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. + transformed = zero_point + std::nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type size() { + return 16; + } + + static constexpr int float_num_vecs() { + return 1; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + Vectorized retval; + auto rhs_data = (__m512)rhs[0]; + at::native::quantize_vec( + scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 16); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi32( + _mm512_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm512_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + + __m512 scaled = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier_v); + __m512i rounded = _mm512_cvtps_epi32(scaled); + return _mm512_add_epi32(rounded, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m512i RequantizeAvx512( + const std::array, 4>& inp, + __m512 multiplier, + __m512i zp) { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + __m512 x_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier); + __m512 y_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[1]), multiplier); + __m512 z_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[2]), multiplier); + __m512 w_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[3]), multiplier); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m512i x_v = _mm512_add_epi32(x_rounded_v, zp); 
+ __m512i y_v = _mm512_add_epi32(y_rounded_v, zp); + __m512i z_v = _mm512_add_epi32(z_rounded_v, zp); + __m512i w_v = _mm512_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m512i xy_packed_v = _mm512_packs_epi32(x_v, y_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_v, w_v); + + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 x8-11 y8-11 z8-11 w8-11 + * x12-15 y12-15 z12-15 w12-15 + */ + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + return xyzw_clamped_v; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + + // This is added to avoid error: definition of implicit copy assignment + // operator for 'Vectorized' is deprecated because it has a + // user-declared copy constructor [-Werror,-Wdeprecated-copy] + Vectorized& operator=(const Vectorized&) = default; + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + private: + __m512i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm512_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_neg_zp_premul) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return 
_mm512_min_epi8(_mm512_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512i int32_val0 = cvtepi8_epi32(int_val0); + __m512i int32_val1 = cvtepi8_epi32(int_val1); + __m512i int32_val2 = cvtepi8_epi32(int_val2); + __m512i int32_val3 = cvtepi8_epi32(int_val3); + +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]); +#else + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); +#endif + + __m512i int32_b0 = cvtepi8_epi32(int_b0); + __m512i int32_b1 = cvtepi8_epi32(int_b1); + __m512i int32_b2 = cvtepi8_epi32(int_b2); + __m512i int32_b3 = cvtepi8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + + // This is added to avoid error: definition of implicit copy assignment + // operator for 'Vectorized' is deprecated because it has a + // user-declared copy constructor [-Werror,-Wdeprecated-copy] + Vectorized& 
operator=(const Vectorized&) = default; + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + private: + __m512i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm512_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + + return {val0, val1, 
val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epu8(_mm512_max_epu8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512i int32_val0 = cvtepu8_epi32(int_val0); + __m512i int32_val1 = cvtepu8_epi32(int_val1); + __m512i int32_val2 = cvtepu8_epi32(int_val2); + __m512i int32_val3 = cvtepu8_epi32(int_val3); + +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]); +#else + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); +#endif + + __m512i int32_b0 = cvtepu8_epi32(int_b0); + __m512i int32_b1 = cvtepu8_epi32(int_b1); + __m512i int32_b2 = cvtepu8_epi32(int_b2); + __m512i int32_b3 = cvtepu8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#else + +// NOTE: These are low-performance implementations that we fall back on. 
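+// [Editorial note, not part of the original header] A minimal usage sketch
+// under the interface shown in this file: callers are written the same way
+// whether the AVX512 specializations above or the scalar fallback below is
+// compiled in, since both expose the same Vectorized<T> surface (loadu/store,
+// dequantize, quantize, maximum, relu, ...). The buffer names `src`, `dst`
+// and the `scale`/`zp` parameters are assumptions for illustration only.
+//
+//   void double_qint8_block(const c10::qint8* src, c10::qint8* dst,
+//                           float scale, int32_t zp) {
+//     using qvec = at::vec::Vectorized<c10::qint8>;
+//     using fvec = at::vec::Vectorized<float>;
+//     qvec q = qvec::loadu(src);                    // one full vector of qint8 lanes
+//     auto fs = q.dequantize(fvec(scale), fvec(static_cast<float>(zp)));
+//     for (auto& f : fs) {
+//       f = f * fvec(2.0f);                         // arbitrary float math per lane
+//     }
+//     qvec out = qvec::quantize(fs, scale, zp, 1.0f / scale);
+//     out.store(dst);
+//   }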
+ +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / 8; + } + + static constexpr int int_num_vecs() { + return size() / 8; + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (const auto i : c10::irange(size())) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul [[maybe_unused]]) const { + float_vec_return_type rv; + for (const auto i : c10::irange(float_num_vecs())) { + float tmp_vals[16]; + for (const auto j : c10::irange(16)) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], zero_point[j], T(vals[16 * i + j])); + } + rv[i] = Vectorized( + tmp_vals[0], + tmp_vals[1], + tmp_vals[2], + tmp_vals[3], + tmp_vals[4], + tmp_vals[5], + tmp_vals[6], + tmp_vals[7], + tmp_vals[8], + tmp_vals[9], + tmp_vals[10], + tmp_vals[11], + tmp_vals[12], + tmp_vals[13], + tmp_vals[14], + tmp_vals[15]); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return dequantize(scale, zero_point, scale_zp_premul); + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (const auto i : c10::irange(size())) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = + std::nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(MSVC) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_base.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_base.h new file mode 100644 index 0000000000000000000000000000000000000000..42ee90d6b7f49b94036b37e435fe6910e187f010 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_base.h @@ -0,0 +1,1512 @@ +#pragma once +#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \ + defined(__ARM_FEATURE_SVE) +// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161 +#pragma GCC optimize("no-tree-vectorize") +#endif + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] +// +// Note [Do not compile initializers with AVX] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// If you define a static initializer in this file, the initialization will use +// AVX instructions because these object files are compiled with AVX enabled. +// We need to avoid non-trivial global data in these architecture specific files +// because there's no way to guard the global initializers with CPU capability +// detection. +// +// See https://github.com/pytorch/pytorch/issues/37577 for an instance +// of this bug in the past. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__GNUC__) +#define __FORCE_INLINE __attribute__((always_inline)) inline +#elif defined(_MSC_VER) +#define __FORCE_INLINE __forceinline +#endif + +#if defined(_MSC_FULL_VER) +/* +https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170 +Use _MSC_FULL_VER to identify current compiler is msvc, +Windows llvm will not have this definition. +*/ +#define __msvc_cl__ +#endif + +// These macros helped us unify vec_base.h +#ifdef CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(64))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(64)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 64 +#define int_vector __m512i +#elif defined(__aarch64__) && \ + !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 +// SVE code expects 256-vectors; leave that set for SVE? +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(16))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(16)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 16 +#else // CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(32))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(32)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 32 +#define int_vector __m256i +#endif // CPU_CAPABILITY_AVX512 + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +// at::Half and at::BFloat16 should be treated as floating point +template +struct is_floating_point + : std::integral_constant< + bool, + std::is_floating_point_v || std::is_same_v || + std::is_same_v> {}; + +template +constexpr bool is_floating_point_v = is_floating_point::value; + +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; + +template +struct is_8bit_integer + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> { +}; + +template +constexpr bool is_8bit_integer_v = is_8bit_integer::value; + +template +struct int_of_size; + +#define DEFINE_INT_OF_SIZE(int_t) \ + template <> \ + struct int_of_size { \ + using type = int_t; \ + } + +DEFINE_INT_OF_SIZE(int64_t); +DEFINE_INT_OF_SIZE(int32_t); +DEFINE_INT_OF_SIZE(int16_t); +DEFINE_INT_OF_SIZE(int8_t); + +#undef DEFINE_INT_OF_SIZE + +template +using int_same_size_t = typename int_of_size::type; + +/** + * Detect at compile time whether Vectorized has an explicit + * specialization for T. (You are required to specialize this type + * whenever you specialize Vectorized). 
Useful for generic algorithms + * to decide whether to rely on a specialization being fast. For + * example, they might choose to handle reduced-precision floating + * point types directly if they're supported, or convert through float + * if not. + */ +#if defined(__s390x__) +template +#else +template +#endif +struct is_vec_specialized_for : std::bool_constant { +}; + +template +constexpr bool is_vec_specialized_for_v = is_vec_specialized_for::value; + +// NOTE: If you specialize Vectorized on a type, you must define all +// operations! You must also specialize is_vec_specialized_for for +// that type. + +// emulates Vectorized types +#if defined(__s390x__) +template +#else +template +#endif +struct Vectorized { + private: + __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; + + public: + using value_type = T; + using size_type = int; + + static constexpr size_type kSize = VECTOR_WIDTH / sizeof(T); + static constexpr size_type size() { + return kSize; + } + Vectorized() : values{static_cast(0)} {} + Vectorized(T val) { + for (int i = 0; i != size(); i++) { + values[i] = val; + } + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) : values{vals...} {} + Vectorized(const T (&arr)[kSize]) { + std::memcpy(values, arr, sizeof(values)); + } + // This also implies const T& operator[](int idx) const + inline operator const T*() const { + return values; + } + // This also implies T& operator[](int idx) + inline operator T*() { + return values; + } + // Return the values as char* for type punning + auto as_bytes() const -> const char* { + return reinterpret_cast(values); + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + int64_t mask = mask_; + Vectorized vector; + for (const auto i : c10::irange(size())) { + if (mask & 0x01) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + mask = mask >> 1; + } + return vector; + } +// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 +#if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE) + static Vectorized __attribute__((optimize("-fno-tree-loop-vectorize"))) + blendv( + const Vectorized& a, +#else + static Vectorized blendv( + const Vectorized& a, +#endif + const Vectorized& b, + const Vectorized& mask) { + Vectorized vector; + int_same_size_t buffer[size()]; + mask.store(buffer); +#if defined(__clang__) && __ARM_FEATURE_SVE +#pragma clang loop vectorize(disable) +#endif + for (const auto i : c10::irange(size())) { + if (buffer[i] & 0x01) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + } + return vector; + } + template // step sometimes requires a higher precision type + // (e.g., T=int, step_t=double) + static Vectorized arange( + T base = static_cast(0), + step_t step = static_cast(1)) { + Vectorized vector; + for (const auto i : c10::irange(size())) { + vector.values[i] = base + i * step; + } + return vector; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + Vectorized vector; + for (const auto i : c10::irange(size())) { + if (i < count) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + } + return vector; + } + static Vectorized loadu(const void* ptr) { + Vectorized vector; + std::memcpy(vector.values, ptr, VECTOR_WIDTH); + return vector; + } + static Vectorized loadu(const void* ptr, int64_t count) { + Vectorized vector; + std::memcpy(vector.values, ptr, count * sizeof(T)); + return vector; + } + static Vectorized 
loadu_one_fourth(const void* ptr) { + static_assert( + std::is_same_v || std::is_same_v, + "For byte types only"); + return Vectorized::loadu(ptr, 8); + } + + void store(void* ptr, int count = size()) const { + std::memcpy(ptr, values, count * sizeof(T)); + } + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (values[i] == static_cast(0)) { + mask |= (1 << i); + } + } + return mask; + } + Vectorized isnan() const { + Vectorized vector; + for (int64_t i = 0; i != size(); i++) { + if (_isnan(values[i])) { + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); + } else { + std::memset(static_cast(vector.values + i), 0, sizeof(T)); + } + } + return vector; + } + bool has_inf_nan() const { + for (int64_t i = 0; i != size(); i++) { + if (_isnan(values[i]) || _isinf(values[i])) { + return true; + } + } + return false; + } +// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows +// Arm64 +// See +// https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 +#if defined(_WIN32) && defined(__aarch64__) && \ + ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) + Vectorized map(T (*const f)(T)) const { + Vectorized ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = f(values[i]); + if (++i < size()) + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(T)) const { + T ret = 0; + for (int64_t i = 0; i < size(); i++) { + ret = f(ret, values[i]); + if (++i < size()) + ret = f(ret, values[i]); + } + return ret; + } +#else + Vectorized map(T (*const f)(T)) const { + Vectorized ret; + for (int64_t i = 0; i != size(); i++) { + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(T)) const { + T ret = 0; + for (int64_t i = 0; i != size(); i++) { + ret = f(ret, values[i]); + } + return ret; + } +#endif + Vectorized map(T (*const f)(const T&)) const { + Vectorized ret; + for (int64_t i = 0; i != size(); i++) { + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(const T&)) const { + T ret = 0; + for (int64_t i = 0; i != size(); i++) { + ret = f(ret, values[i]); + } + return ret; + } + template < + typename other_t_abs = T, + typename std::enable_if_t< + !is_floating_point_v && + !c10::is_complex::value, + int> = 0> + Vectorized abs() const { + // other_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_abs must be T"); + return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); + } + template < + typename float_t_abs = T, + typename std::enable_if_t, int> = 0> + Vectorized abs() const { + // float_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "float_t_abs must be T"); + // Specifically deal with floating-point because the generic code above + // won't handle -0.0 (which should result in 0.0) properly. + return map([](T x) -> T { return std::abs(x); }); + } + template < + typename complex_t_abs = T, + typename std::enable_if_t::value, int> = 0> + Vectorized abs() const { + // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "complex_t_abs must be T"); + // Specifically map() does not perform the type conversion needed by abs. 
+ return map([](T x) { return static_cast(std::abs(x)); }); + } + + template < + typename other_t_sgn = T, + typename std::enable_if_t::value, int> = 0> + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + template < + typename other_t_angle = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized angle() const { + // other_t_angle is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_angle must be T"); + return map(at::native::angle_impl); // compiler is unable to resolve the + // overload without + } + template < + typename complex_t_angle = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized angle() const { + // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_angle must be T"); + return map([](T x) { return static_cast(std::arg(x)); }); + } + template < + typename other_t_real = T, + typename std::enable_if_t::value, int> = 0> + Vectorized real() const { + // other_t_real is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_real must be T"); + return *this; + } + template < + typename complex_t_real = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized real() const { + // complex_t_real is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_real must be T"); + return map([](T x) { return static_cast(x.real()); }); + } + template < + typename other_t_imag = T, + typename std::enable_if_t::value, int> = 0> + Vectorized imag() const { + // other_t_imag is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_imag must be T"); + return Vectorized(0); + } + template < + typename complex_t_imag = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized imag() const { + // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_imag must be T"); + return map([](T x) { return static_cast(x.imag()); }); + } + template < + typename other_t_conj = T, + typename std::enable_if_t::value, int> = 0> + Vectorized conj() const { + // other_t_conj is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_conj must be T"); + return *this; + } + template < + typename complex_t_conj = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized conj() const { + // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. 
+ static_assert( + std::is_same_v, "complex_t_conj must be T"); + return map([](T x) { return static_cast(std::conj(x)); }); + } + Vectorized acos() const { + return map(std::acos); + } + Vectorized acosh() const { + return map(std::acosh); + } + Vectorized asin() const { + return map(std::asin); + } + Vectorized asinh() const { + return map(std::asinh); + } + Vectorized atan() const { + return map(std::atan); + } + Vectorized atanh() const { + return map(std::atanh); + } + Vectorized atan2(const Vectorized& exp) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::atan2(values[i], exp[i]); + } + return ret; + } + template < + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized copysign(const Vectorized& sign) const { + Vectorized ret; + for (size_type i = 0; i < size(); i++) { + ret[i] = c10::copysign(values[i], sign[i]); + } + return ret; + } + Vectorized erf() const { + return map(std::erf); + } + Vectorized erfc() const { + return map(std::erfc); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + Vectorized exp_u20() const { + return map(std::exp); + } + Vectorized frac() const { + return *this - this->trunc(); + } + template < + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized fmod(const Vectorized& q) const { + // U is for SFINAE purposes only. Make sure it is not changed. + static_assert(std::is_same_v, "U must be T"); + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::fmod(values[i], q[i]); + } + return ret; + } + Vectorized log() const { + return map(std::log); + } + Vectorized log10() const { + return map(std::log10); + } + Vectorized log1p() const { + return map(std::log1p); + } + template < + typename other_t_log2 = T, + typename std::enable_if_t::value, int> = 0> + Vectorized log2() const { + // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_log2 must be T"); + return map(std::log2); + } + template < + typename complex_t_log2 = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized log2() const { + // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. 
+ static_assert( + std::is_same_v, "complex_t_log2 must be T"); + const T log_2 = T(std::log(2.0)); + return Vectorized(map(std::log)) / Vectorized(log_2); + } + Vectorized ceil() const { + return map(at::native::ceil_impl); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + Vectorized floor() const { + return map(at::native::floor_impl); + } + Vectorized hypot(const Vectorized& b) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::hypot(values[i], b[i]); + } + return ret; + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = calc_igamma(values[i], x[i]); + } + return ret; + } + Vectorized igammac(const Vectorized& x) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = calc_igammac(values[i], x[i]); + } + return ret; + } + Vectorized neg() const { + // NB: the trailing return type is needed because we need to coerce the + // return value back to T in the case of unary operator- incuring a + // promotion + return map([](T x) -> T { return -x; }); + } + Vectorized nextafter(const Vectorized& b) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::nextafter(values[i], b[i]); + } + return ret; + } + Vectorized round() const { + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. + return map(at::native::round_impl); + } + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized trunc() const { + return map(at::native::trunc_impl); + } + Vectorized lgamma() const { + return map(std::lgamma); + } + Vectorized sqrt() const { + return map(std::sqrt); + } + Vectorized reciprocal() const { + return map([](T x) { return (T)(1) / x; }); + } + Vectorized rsqrt() const { + return map([](T x) { return (T)1 / std::sqrt(x); }); + } + Vectorized pow(const Vectorized& exp) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::pow(values[i], exp[i]); + } + return ret; + } + T reduce_add() const { + return reduce([](T x, T y) -> T { return x + y; }); + } + T reduce_max() const { + return reduce(std::max); + } + + private: + template + inline Vectorized binary_pred(const Vectorized& other, Op op) const { + // All bits are set to 1 if the pred is true, otherwise 0. 
+ Vectorized vector; + for (int64_t i = 0; i != size(); i++) { + if (op(values[i], other.values[i])) { + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); + } else { + std::memset(static_cast(vector.values + i), 0, sizeof(T)); + } + } + return vector; + } + + public: + Vectorized operator==(const Vectorized& other) const { + return binary_pred(other, std::equal_to()); + } + Vectorized operator!=(const Vectorized& other) const { + return binary_pred(other, std::not_equal_to()); + } + Vectorized operator>=(const Vectorized& other) const { + return binary_pred(other, std::greater_equal()); + } + Vectorized operator<=(const Vectorized& other) const { + return binary_pred(other, std::less_equal()); + } + Vectorized operator>(const Vectorized& other) const { + return binary_pred(other, std::greater()); + } + Vectorized operator<(const Vectorized& other) const { + return binary_pred(other, std::less()); + } + + private: + template + inline Vectorized binary_pred_bool(const Vectorized& other, Op op) + const { + // 1 if the pred is true, otherwise 0. + Vectorized vector; + for (int i = 0; i != size(); ++i) { + vector[i] = static_cast(op(values[i], other.values[i])); + } + return vector; + } + + public: + Vectorized eq(const Vectorized& other) const { + return binary_pred_bool(other, std::equal_to()); + } + Vectorized ne(const Vectorized& other) const { + return binary_pred_bool(other, std::not_equal_to()); + } + Vectorized gt(const Vectorized& other) const { + return binary_pred_bool(other, std::greater()); + } + Vectorized ge(const Vectorized& other) const { + return binary_pred_bool(other, std::greater_equal()); + } + Vectorized lt(const Vectorized& other) const { + return binary_pred_bool(other, std::less()); + } + Vectorized le(const Vectorized& other) const { + return binary_pred_bool(other, std::less_equal()); + } +}; + +template +Vectorized inline operator-(const Vectorized& a) { + return a.neg(); +} + +// There is an implicit conversion that would make this work if +// these operators weren't template functions, but they are template +// functions (and can't be moved to be non-member friends defined in +// the class body as suggested in +// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255 +// because we have a lot of disparate specializations of +// Vectorized). So, just explicitly make scalars work. 
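+// [Editorial note, not part of the original header] For illustration, the
+// macros defined below generate the scalar overloads that the comment above
+// motivates; with them in place, mixed vector/scalar calls such as the
+// following work without relying on implicit conversions (a sketch, assuming
+// the float specialization):
+//
+//   at::vec::Vectorized<float> v(3.0f);
+//   auto s = v + 1.0f;                        // operator+(v, Vectorized<float>(1.0f))
+//   auto m = at::vec::maximum(v, 1.0f);       // maximum(v, Vectorized<float>(1.0f))
+//   auto c = at::vec::clamp(v, 0.0f, 6.0f);   // ternary variant, same expansion idea
+//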
+#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \ + template \ + Vectorized inline name(const Vectorized& a, T b) { \ + return name(a, Vectorized(b)); \ + } \ + template \ + Vectorized inline name(T a, const Vectorized& b) { \ + return name(Vectorized(a), b); \ + } +#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \ + VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op) + +template +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] + b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+) + +template +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] - b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-) + +template +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] * b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*) + +template +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] / b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/) + +template , int> = 0> +Vectorized inline operator%(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { + return a - a / b * b; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%) + +template +Vectorized inline operator||( + const Vectorized& a, + const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] || b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||) + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (a[i] > b[i]) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum) + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (a[i] < b[i]) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. 
+ c[i] = a[i]; + } + } + return c; +} + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_vec, + const Vectorized& max_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); + } + return c; +} + +#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \ + template \ + Vectorized inline name( \ + const Vectorized& a, const Vectorized& b, T c) { \ + return name(a, b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + const Vectorized& a, T b, const Vectorized& c) { \ + return name(a, Vectorized(b), c); \ + } \ + \ + template \ + Vectorized inline name(const Vectorized& a, T b, T c) { \ + return name(a, Vectorized(b), Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + T a, const Vectorized& b, const Vectorized& c) { \ + return name(Vectorized(a), b, c); \ + } \ + \ + template \ + Vectorized inline name(T a, const Vectorized& b, T c) { \ + return name(Vectorized(a), b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name(T a, T b, const Vectorized& c) { \ + return name(Vectorized(a), Vectorized(b), c); \ + } + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] < min_vec[i] ? 
min_vec[i] : a[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min) + +struct Vectorizedi; + +#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { + int_vector buffer; +#if defined(CPU_CAPABILITY_AVX2) + int_vector a_buffer = + _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm256_load_si256(reinterpret_cast((const T*)b)); +#elif defined(CPU_CAPABILITY_AVX512) + int_vector a_buffer = + _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm512_load_si512(reinterpret_cast((const T*)b)); +#endif + buffer = op(a_buffer, b_buffer); + __at_align__ T results[Vectorized::size()]; + +#if defined(CPU_CAPABILITY_AVX2) + _mm256_store_si256(reinterpret_cast(results), buffer); +#elif defined(CPU_CAPABILITY_AVX512) + _mm512_store_si512(reinterpret_cast(results), buffer); +#endif + return Vectorized::loadu(results); +} + +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); +#endif +} +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); +#endif +} +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); +#endif +} + +#else + +template +auto load(char const* data) -> T { + T ret; + std::memcpy(&ret, data, sizeof(ret)); + return ret; +} + +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { + static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); + __at_align__ intmax_t buffer[element_no]; + static_assert( + VECTOR_WIDTH % sizeof(intmax_t) == 0, + "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); + static_assert( + sizeof(buffer) == sizeof(Vectorized), + "sizeof(buffer) must match sizeof(Vectorized)"); + // We should be using memcpy in order to respect the strict aliasing rule + // see: https://github.com/pytorch/pytorch/issues/66119 + // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 + // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf) + const auto* a_data = 
a.as_bytes(); + const auto* b_data = b.as_bytes(); + // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t) + for (auto& out : buffer) { + out = op(load(a_data), load(b_data)); + a_data += sizeof(intmax_t); + b_data += sizeof(intmax_t); + } + assert(a_data == a.as_bytes() + sizeof(a)); + assert(b_data == b.as_bytes() + sizeof(b)); + return Vectorized::loadu(buffer); +} + +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_and()); +} +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_or()); +} +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_xor()); +} + +#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^) + +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator~(const Vectorized& a) { + using int_t = int_same_size_t; + Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 + return a ^ ones; +} + +template +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + constexpr T max_shift = sizeof(T) * CHAR_BIT; + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + T shift = b[i]; + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { + c[i] = 0; + } else { + c[i] = static_cast>(a[i]) << shift; + } + } + return c; +} + +template +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + // right shift value to retain sign bit for signed and no bits for unsigned + constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v; + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + T shift = b[i]; + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { + c[i] = a[i] >> max_shift; + } else { + c[i] = a[i] >> shift; + } + } + return c; +} + +template +inline Vectorized& operator+=(Vectorized& a, const Vectorized& b) { + a = a + b; + return a; +} +template +inline Vectorized& operator-=(Vectorized& a, const Vectorized& b) { + a = a - b; + return a; +} +template +inline Vectorized& operator/=(Vectorized& a, const Vectorized& b) { + a = a / b; + return a; +} +template +inline Vectorized& operator%=(Vectorized& a, const Vectorized& b) { + a = a % b; + return a; +} +template +inline Vectorized& operator*=(Vectorized& a, const Vectorized& b) { + a = a * b; + return a; +} + +template +inline Vectorized& operator<<=(Vectorized& a, const Vectorized& b) { + a = a << b; + return a; +} + +template +inline Vectorized& operator>>=(Vectorized& a, const Vectorized& b) { + a = a >> b; + return a; +} + +template +inline Vectorized fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b + c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd) + +template +inline Vectorized fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b - c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub) + +template +Vectorized inline operator&&( + const Vectorized& a, + const Vectorized& b) { + Vectorized ret; + for (int i = 0; i != Vectorized::size(); 
i++) { + ret[i] = a[i] && b[i]; + } + return ret; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&) + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + T>> inline gather(T const* base_addr, const Vectorized>& vindex) { + static constexpr int size = Vectorized::size(); + int_same_size_t index_arr[size]; + vindex.store(static_cast(index_arr)); + T buffer[size]; + for (const auto i : c10::irange(size)) { + buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; + } + return Vectorized::loadu(static_cast(buffer)); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + T const* base_addr, + const Vectorized>& vindex, + Vectorized& mask) { + static constexpr int size = Vectorized::size(); + T src_arr[size]; + int_same_size_t mask_arr[size]; // use int type so we can logical and + int_same_size_t index_arr[size]; + src.store(static_cast(src_arr)); + mask.store(static_cast(mask_arr)); + vindex.store(static_cast(index_arr)); + T buffer[size]; + for (const auto i : c10::irange(size)) { + if (mask_arr[i] & 0x01) { // check highest bit + buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; + } else { + buffer[i] = src_arr[i]; + } + } + mask = Vectorized(static_cast(0)); // "zero out" mask + return Vectorized::loadu(static_cast(buffer)); +} + +// Cast a given vector to another type without changing the bits representation. +// So a Vectorized of 512 bits containing all ones can be cast to a +// Vectorized of 512 bits containing all ones (i.e., eight negative +// 1s). A Vec of 256 bits containing all ones can be cast to a +// Vec of 256 bits containing all ones (i.e., four negative 1s). +// There is a struct here because we don't have static_if and I can't +// partially specialize a templated function. 
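// Illustrative sketch (standalone): the bit-reinterpreting cast described in
// the comment above, reduced to plain arrays. Float lanes are copied into
// int32_t storage byte for byte, so a lane of all-one bits stays all-one bits;
// this is not a value conversion. memcpy is used to stay clear of strict
// aliasing, as the header's own fallback paths do. Names are local to the
// sketch.
#include <cstdint>
#include <cstring>
#include <iostream>

template <class dst_t, class src_t, int N>
void bitcast_lanes(const src_t (&src)[N], dst_t (&dst)[N]) {
  static_assert(sizeof(dst_t) == sizeof(src_t),
                "bit cast keeps the width of every lane");
  std::memcpy(dst, src, sizeof(src));  // copy the raw bits, lane layout unchanged
}

int main() {
  float f[2];
  std::memset(f, 0xFF, sizeof(f));     // two float lanes of all-one bits
  std::int32_t i[2];
  bitcast_lanes(f, i);
  std::cout << i[0] << ' ' << i[1] << '\n';  // -1 -1: all-one bits read as int32
}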
+template +struct CastImpl { + static inline Vectorized apply(const Vectorized& src) { + src_t src_arr[Vectorized::size()]; + src.store(static_cast(src_arr)); + return Vectorized::loadu(static_cast(src_arr)); + } +}; + +template +struct CastImpl { + static inline Vectorized apply(const Vectorized& src) { + return src; + } +}; + +template +inline Vectorized cast(const Vectorized& src) { + return CastImpl::apply(src); +} + +template > +inline Vectorized convert_to_int_of_same_size( + const Vectorized& src) { + static_assert(sizeof(T) == sizeof(IntType)); + static constexpr int size = Vectorized::size(); + + std::array src_arr = {}; + src.store(static_cast(src_arr.data())); + std::array buffer; + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) { + return static_cast(x); + }); + return Vectorized::loadu(static_cast(buffer.data())); +} + +template > +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src) { + static_assert(sizeof(T) == sizeof(IntType)); + static constexpr int size = Vectorized::size(); + + std::array src_arr; + src.store(static_cast(src_arr.data())); + std::array buffer; + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) { + return static_cast(x); + }); + return Vectorized::loadu(static_cast(buffer.data())); +} + +// clang-format off +// Example inputs for AVX512: +// a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// returns: +// Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// Example inputs for AVX2: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// clang-format on +template +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> +deinterleave2(const Vectorized& a, const Vectorized& b) { + static constexpr int size = Vectorized::size(); + static constexpr int half_size = size / 2; + T a_arr[size]; + T b_arr[size]; + T buffer1[size]; + T buffer2[size]; + a.store(static_cast(a_arr)); + b.store(static_cast(b_arr)); + for (const auto i : c10::irange(half_size)) { + buffer1[i] = a_arr[i * 2]; + buffer1[half_size + i] = b_arr[i * 2]; + buffer2[i] = a_arr[i * 2 + 1]; + buffer2[half_size + i] = b_arr[i * 2 + 1]; + } + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2) + +// clang-format off +// inverse operation of deinterleave2 +// Example inputs for AVX512: +// a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// returns, for AVX512: +// Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// Example inputs for AVX2 : a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// clang-format on +template +inline std::enable_if_t< + 
Vectorized::size() % 2 == 0, + std::pair, Vectorized>> +interleave2(const Vectorized& a, const Vectorized& b) { + static constexpr int size = Vectorized::size(); + static constexpr int half_size = size / 2; + T a_arr[size]; + T b_arr[size]; + T buffer1[size]; + T buffer2[size]; + a.store(static_cast(a_arr)); + b.store(static_cast(b_arr)); + for (const auto i : c10::irange(half_size)) { + buffer1[i * 2] = a_arr[i]; + buffer1[i * 2 + 1] = b_arr[i]; + buffer2[i * 2] = a_arr[half_size + i]; + buffer2[i * 2 + 1] = b_arr[half_size + i]; + } + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2) + +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP +#undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC + +template +inline void convert(const src_T* src, dst_T* dst, int64_t n) { +#ifndef _MSC_VER +#pragma unroll +#endif + for ([[maybe_unused]] const auto i : c10::irange(n)) { + *dst = c10::convert(c10::load(src)); + src++; + dst++; + } +} + +template +inline Vectorized flip(const Vectorized& data) { + static constexpr int size = Vectorized::size(); + T output[size]; + T buffer[size]; + data.store(static_cast(buffer)); + for (const auto i : c10::irange(size)) { + output[i] = buffer[size - i - 1]; + } + return Vectorized::loadu(static_cast(output)); +} + +// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. +// `ld_src` is the leading dimension of `src` and `ld_dst` is the leading +// dimension of `dst`. +template +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst, + int M, + int N) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + dst[j * ld_dst + i] = src[i * ld_src + j]; + } + } +} + +template +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +// additional headers for more operations that depend on vec_base +#include +#include +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_convert.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..60264c1ef0119529cd18dd3301d22a8caa58affd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_convert.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + constexpr int count = std::min( + VectorizedN::size(), VectorizedN::size()); + __at_align__ src_t src_buf[VectorizedN::size()]; + src.store(src_buf); + __at_align__ dst_t dst_buf[VectorizedN::size()]; + for (int i = 0; i < count; i++) { + dst_buf[i] = static_cast(src_buf[i]); + } + return VectorizedN::loadu(dst_buf, count); + } +}; + +template +inline std::enable_if_t, Vectorized> convert( + const Vectorized& src) { + return src; +} + +template +inline std::enable_if_t, Vectorized> +convert(const Vectorized& src) { + return VecConvert::apply(src); +} + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + std::enable_if_t = 0> +inline VectorizedN convert(const VectorizedN& src) { + return VecConvert::apply(src); +} + 
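// Illustrative sketch (standalone): the generic fallback pattern used by
// VecConvert above, reduced to plain arrays. The source lanes are spilled to a
// buffer, converted element by element with static_cast, and the result plays
// the role of the reloaded destination vector. Real specializations replace
// this loop with single SIMD conversion instructions.
#include <algorithm>
#include <cstdint>
#include <iostream>

template <class dst_t, class src_t, int N>
void convert_ref(const src_t (&src)[N], dst_t (&dst)[N]) {
  src_t src_buf[N];
  std::copy(src, src + N, src_buf);            // "store" step
  for (int i = 0; i < N; i++) {
    dst[i] = static_cast<dst_t>(src_buf[i]);   // per-lane conversion
  }                                            // a "loadu" step would follow
}

int main() {
  float in[4] = {1.5f, -2.25f, 3.0f, 100.75f};
  std::int32_t out[4];
  convert_ref(in, out);
  for (int v : out) std::cout << v << ' ';     // 1 -2 3 100
  std::cout << '\n';
}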
+template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + bool keep = false, + std::enable_if_t = 0> +inline std::conditional_t, Vectorized> +convert(const VectorizedN& src) { + return VecConvert::apply(src); +} + +} // namespace CPU_CAPABILITY + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized&); + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline Vectorized convert_from_float( + const Vectorized&, + const Vectorized&); + +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_half.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_half.h new file mode 100644 index 0000000000000000000000000000000000000000..556dd3826e12dbaba4509295e1e627b0a9d5f550 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_half.h @@ -0,0 +1,156 @@ +#pragma once + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +static inline uint16_t float2half_scalar(float val) { +#if defined(CPU_CAPABILITY_AVX2) +#if defined(_MSC_VER) + __m256 v = _mm256_set1_ps(val); + __m128i o = + _mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return static_cast(_mm_cvtsi128_si32(o)); +#else + return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); +#endif +#elif defined(CPU_CAPABILITY_AVX512) + __m512 v = _mm512_set1_ps(val); + __m256i o = + _mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return static_cast( + _mm_cvtsi128_si32(_mm256_castsi256_si128(o))); +#endif +} + +static inline float half2float_scalar(uint16_t val) { +#if defined(CPU_CAPABILITY_AVX2) +#if defined(_MSC_VER) + __m128i v = _mm_cvtsi32_si128(val); + __m256 o = _mm256_cvtph_ps(v); + return _mm256_cvtss_f32(o); +#else + return _cvtsh_ss(val); +#endif +#elif defined(CPU_CAPABILITY_AVX512) + __m256i v = + _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512 o = _mm512_cvtph_ps(v); + return _mm512_cvtss_f32(o); +#endif +} + +#endif + +// Transpose a [2, 32] matrix to [32, 2] +// Note: the output leading dimension should be 2, +// that is, the output must be contiguous +template > +static inline void transpose_pad_2x32_block( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int krem = 2, + int nrem = 32) { +#if defined(CPU_CAPABILITY_AVX512) + __m512i r0, r1; + __m512i d0, d1; + // load + if (nrem < 32) { + __mmask32 mask_krem_v = (1LL << nrem) - 1; + r0 = _mm512_maskz_loadu_epi16(mask_krem_v, src); + // if krem is not 2, pad with zeros + if (krem == 2) { + r1 = _mm512_maskz_loadu_epi16(mask_krem_v, src + ld_src); + } else { + r1 = _mm512_setzero_si512(); + } + } else { + r0 = _mm512_loadu_si512(reinterpret_cast(src)); + if (krem == 2) { + r1 = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); + } else { + r1 = _mm512_setzero_si512(); + } + } + // transpose + d0 = _mm512_unpacklo_epi16(r0, r1); + d1 = _mm512_unpackhi_epi16(r0, r1); + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); + d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); + d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); + + // store + if (nrem < 16) { + __mmask32 mask_rem_v = (1LL << (nrem * 2)) - 1; + _mm512_mask_storeu_epi16(dst, mask_rem_v, d0); + } else if (nrem == 16) { + 
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + } else if (nrem < 32) { + __mmask32 mask_rem_v = (1LL << (nrem * 2 - 32)) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + _mm512_mask_storeu_epi16( + reinterpret_cast<__m512i*>(dst + 32), mask_rem_v, d1); + } else { + // normal store + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); + } +#else + TORCH_CHECK( + false, + "transpose_pad_2x32_block is only supported when avx512 is supported") +#endif +} + +// To use AMX to accelerate GEMM, +// reorder the memory format [K, N] -> [K/2, N, 2] +// Note: If K % 2 != 0, pad K implicitly +template > +static inline void pack_vnni2( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int64_t K, + int64_t N) { +#if defined(CPU_CAPABILITY_AVX512) + int64_t bk = 0; + int64_t _K = K / 2 * 2; + int64_t _N = N / 32 * 32; + for (; bk < _K; bk += 2) { + int64_t bn = 0; + for (; bn < _N; bn += 32) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); + } + } + if (K % 2 == 1) { + int64_t bn = 0; + for (; bn < _N; bn += 32) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); + } + } +#else + TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") +#endif +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_mask.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..a945ea2e0a4e59f37516821fb1ae9d6d2d381ed0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_mask.h @@ -0,0 +1,300 @@ +#pragma once + +#include +#include +namespace at::vec { +inline namespace CPU_CAPABILITY { + +/** + * The `VecMask` class provides a convenient interface for working with + * vectorized masks in SIMD operations. It encapsulates a `Vectorized` + * mask that can be directly usable in masked vectorized operations. It provides + * various methods for manipulating and accessing the mask elements: + * 1. `from` and `to`: Conversion between a vector of boolean values and a + * vectorized mask. + * 2. `cast`: Casts the mask to a different base type. + * 3. `all_zero`: Checks if all mask elements are zero. + * 4. `is_masked`: Checks if a specific element is masked. + * 5. `loadu`: Loads data from memory using the mask. + * 6. `all_masked`: Checks if all mask elements are masked. + * + * Some helper template classes are provided to simplify the specialization of + * the `VecMask` for the specific CPU arch: + * 1. `VecMaskLoad`: Loads data from memory using the mask. + * 2. `VecMaskTo`: Converts the mask to boolean. + * 3. `VecMaskCast`: Casts the mask to a different base type. 
+ * + */ +template +class VecMask; + +template < + typename data_t, + int data_n, + typename mask_t, + int mask_n, + typename Enabled = void> +struct VecMaskLoad { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + constexpr typename VecMask::size_type size = + VecMask::size(); + static_assert(VectorizedN::size() >= size); + __at_align__ data_t data[size]; + __at_align__ mask_t mask[size]; + auto mask_ = VectorizedN(vec_mask); + mask_.store(mask); + for (int i = 0; i < size; i++) { + data[i] = mask[i] ? ptr[i] : static_cast(0); + } + return VectorizedN::loadu(data, size); + } +}; + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecMaskTo { + static inline VecMask apply( + const VecMask& vec_mask) { + auto zeros = VectorizedN(static_cast(0)); + auto ones = VectorizedN(static_cast(1)); + return VectorizedN::blendv( + zeros, ones, vec_mask.template cast()); + } +}; + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecMaskCast { + static inline VecMask apply( + const VecMask& vec_mask) { + return VecMask::from(VectorizedN(vec_mask)); + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + return vec_mask; + } +}; + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m == static_cast(0); + }); + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m != static_cast(0); + }); + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return mask[i] != static_cast(0); + } +}; + +template +class VecMask { + public: + using size_type = int; + static constexpr size_type size() { + return VectorizedN::size(); + } + + private: + VectorizedN mask_; + + public: + VecMask() : mask_(static_cast(0)) {} + VecMask(const VectorizedN& mask) : mask_(mask) {} + + template = 0> + VecMask(const Vectorized& mask) : mask_(mask) {} + + template + static VecMask from(const VectorizedN& b_vec) { + __at_align__ U b_buf[size()]; + if constexpr (size() >= VectorizedN::size()) { + b_vec.store(b_buf); + for (int i = VectorizedN::size(); i < size(); i++) { + b_buf[i] = static_cast(0); + } + } else { + b_vec.store(b_buf, size()); + } + return from(b_buf); + } + + template + static VecMask from(U b) { + using int_t = int_same_size_t; + T mask = b ? c10::bit_cast((int_t)(~(int_t)0)) : (T)0; + return VectorizedN(mask); + } + + template + static VecMask from(U* b) { + using int_t = int_same_size_t; + __at_align__ T mask[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (int i = 0; i < size(); i++) { + *(int_t*)(mask + i) = b[i] ? 
~(int_t)0 : (int_t)0; + } + return VectorizedN(VectorizedN::loadu(mask)); + } + + static VecMask blendv( + const VecMask& c, + const VecMask& b, + const VecMask& a) { + VectorizedN result = VectorizedN::blendv( + VectorizedN(c), VectorizedN(b), VectorizedN(a)); + return result; + } + + static VecMask set( + const VecMask& a, + const VecMask& b, + int64_t count = size()) { + VectorizedN result = VectorizedN::set( + VectorizedN(a), VectorizedN(b), count); + return result; + } + + void store(bool* b, int count = size()) { + constexpr int L = + (VectorizedN::size() + Vectorized::size() - 1) / + Vectorized::size(); + auto res = this->to(); + res.store(b, count); + return; + } + + template = 2, int> = 0> + inline VectorizedN to() const { + return VecMaskTo::apply(*this); + } + + template = 0> + inline Vectorized to() const { + return VecMaskTo::apply(*this); + } + + template + inline VecMask cast() const { + return VecMaskCast::apply(*this); + } + + inline bool all_zero() const { + return VecMaskCheck::all_zero(mask_); + } + + inline bool all_masked() const { + return VecMaskCheck::all_masked(mask_); + } + + inline bool is_masked(int i) const { + return VecMaskCheck::is_masked(mask_, i); + } + + inline operator VectorizedN() const { + return mask_; + } + + template = 0> + inline operator Vectorized() const { + return mask_[0]; + } + + inline Vectorized operator[](int i) const { + return mask_[i]; + } + + template < + typename U, + int L, + std::enable_if_t= 2 && VectorizedN::size() >= size(), int> = 0> + VectorizedN loadu(const U* ptr) const { + return VecMaskLoad::apply(ptr, *this); + } + + template < + typename U, + int L, + std::enable_if_t::size() >= size(), int> = 0> + Vectorized loadu(const U* ptr) const { + return VecMaskLoad::apply(ptr, *this); + } +}; + +#define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \ + template \ + inline VecMask op(const VecMask& a) { \ + return op(VectorizedN(a)); \ + } + +#define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \ + template < \ + typename T, \ + int N, \ + typename V, \ + int M, \ + std::enable_if_t::size() == VecMask::size(), int> = \ + 0> \ + inline VecMask op(const VecMask& a, const VecMask& b) { \ + return op( \ + VectorizedN(a), VectorizedN(b.template cast())); \ + } + +#define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \ + template < \ + typename T, \ + int N, \ + typename V, \ + int M, \ + std::enable_if_t::size() == VecMask::size(), int> = \ + 0> \ + inline VecMask op(const VecMask& a, const VecMask& b) { \ + return EXPR; \ + } + +VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator*) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b)) + +#undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL +#undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL +#undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_n.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_n.h new file mode 100644 index 
0000000000000000000000000000000000000000..84624ac7c985ff4454ace81a90bf254c0c5c54e8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vec/vec_n.h @@ -0,0 +1,405 @@ +#pragma once + +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +/** + * @brief A class template representing a vectorized type with + * `N * Vectorized::size()` elements, aiming to support vectors of + * arbitrary size. A specific use case of it is to represent vectors + * converted from data types with different sizes but with the same + * number of vector elements, e.g., `VectorizedN` can be + * a vector converted from two `Vectorized`, `VectorizedN` + * can be a vector converted from two `Vectorized` etc. + * + * It supports most of the operations of `Vectorized` + * and the implementation delegates to `Vectorized` with loops over `N`. + * + * @tparam T The underlying type of the vectorized elements. + * @tparam N The number of underlying `Vectorized`. + */ +template +class VectorizedN { + public: + using value_type = T; + using size_type = int; + + static constexpr size_type size_T = sizeof(T); + static constexpr size_type size() { + return Vectorized::size() * N; + } + + private: + std::array, N> values; + + public: + // methods not implemented yet: + // variadic constructor, operator T*, as_bytes, zero_mask + +#define VECTORIZEDN_DEFINE_UNARY_OP(op) \ + VectorizedN op() const { \ + return unary_op([](const Vectorized& a) { return a.op(); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP(op) \ + VectorizedN op(const VectorizedN& other) const { \ + return binary_op( \ + other, [](const Vectorized& a, const Vectorized& b) { \ + return a.op(b); \ + }); \ + } + + template + inline VectorizedN unary_op(Op op) const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i]); + } + return result; + } + + template + inline VectorizedN binary_op(const VectorizedN& other, Op op) + const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i], other.values[i]); + } + return result; + } + + template + inline VectorizedN ternary_op( + const VectorizedN& other, + const VectorizedN& other2, + Op op) const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i], other.values[i], other2.values[i]); + } + return result; + } + + VectorizedN() = default; + + explicit VectorizedN(T val) { + for (int i = 0; i < N; ++i) { + values[i] = Vectorized(val); + } + } + + template = 0> + VectorizedN(const Vectorized& val) : values({val}) {} + + template = 0> + VectorizedN(const Vectorized& val_0, const Vectorized& val_1) + : values({val_0, val_1}) {} + + template = 0> + inline operator Vectorized() const { + return values[0]; + } + + inline const Vectorized& operator[](int i) const { + return values[i]; + } + + inline Vectorized& operator[](int i) { + return values[i]; + } + + template + static VectorizedN blend( + const VectorizedN& a, + const VectorizedN& b) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = + Vectorized::template blend(a.values[i], b.values[i]); + } + return result; + } + + static VectorizedN blendv( + const VectorizedN& a, + const VectorizedN& b, + const VectorizedN& mask) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = + Vectorized::blendv(a.values[i], b.values[i], mask.values[i]); + } + return result; + } 
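// Illustrative sketch (standalone): the delegation pattern VectorizedN relies
// on, reduced to std::array "chunks". Every operation loops over the N inner
// chunks and forwards to the chunk-level operation. ChunkVec and ChunkVecN are
// local stand-ins for Vectorized<T> and VectorizedN<T, N>, not the real
// classes.
#include <array>
#include <iostream>

template <class T, int Lanes>
struct ChunkVec {                        // stand-in for Vectorized<T>
  std::array<T, Lanes> v{};
  ChunkVec plus(const ChunkVec& o) const {
    ChunkVec r;
    for (int i = 0; i < Lanes; i++) r.v[i] = v[i] + o.v[i];
    return r;
  }
};

template <class T, int Lanes, int N>
struct ChunkVecN {                       // stand-in for VectorizedN<T, N>
  std::array<ChunkVec<T, Lanes>, N> chunks;
  template <class Op>
  ChunkVecN binary_op(const ChunkVecN& other, Op op) const {
    ChunkVecN result;
    for (int i = 0; i < N; i++) {        // delegate chunk by chunk
      result.chunks[i] = op(chunks[i], other.chunks[i]);
    }
    return result;
  }
};

int main() {
  ChunkVecN<float, 4, 2> a, b;
  for (int c = 0; c < 2; c++)
    for (int i = 0; i < 4; i++) { a.chunks[c].v[i] = i; b.chunks[c].v[i] = 10; }
  auto s = a.binary_op(b, [](const auto& x, const auto& y) { return x.plus(y); });
  std::cout << s.chunks[1].v[3] << '\n';  // 13
}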
+ + template + static VectorizedN arange( + T base = static_cast(0), + step_t step = static_cast(1)) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::arange(base, step); + base += step * Vectorized::size(); + } + return result; + } + + static VectorizedN set( + const VectorizedN& a, + const VectorizedN& b, + int64_t count = size()) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + if (count > 0) { + result.values[i] = Vectorized::set( + a.values[i], + b.values[i], + std::min(count, (int64_t)Vectorized::size())); + count -= Vectorized::size(); + } else { + result.values[i] = a.values[i]; + } + } + return result; + } + + static VectorizedN loadu(const void* ptr) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::loadu(ptr); + ptr = static_cast(ptr) + Vectorized::size(); + } + return result; + } + + static VectorizedN loadu(const void* ptr, int64_t count) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::loadu( + ptr, std::min(count, (int64_t)Vectorized::size())); + ptr = static_cast(ptr) + Vectorized::size(); + count -= Vectorized::size(); + if (count <= 0) { + break; + } + } + return result; + } + + void store(void* ptr) const { + for (int i = 0; i < N; ++i) { + values[i].store(ptr); + ptr = static_cast(ptr) + Vectorized::size(); + } + } + + void store(void* ptr, int count) const { + for (int i = 0; i < N; ++i) { + values[i].store(ptr, std::min(count, (int)Vectorized::size())); + ptr = static_cast(ptr) + Vectorized::size(); + count -= Vectorized::size(); + if (count <= 0) { + break; + } + } + } + + bool has_inf_nan() const { + for (int i = 0; i < N; ++i) { + if (values[i].has_inf_nan()) { + return true; + } + } + return false; + } + + VectorizedN map(T (*const f)(T)) const { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = values[i].map(f); + } + return result; + } + + VectorizedN map(T (*const f)(const T&)) const { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = values[i].map(f); + } + return result; + } + + VECTORIZEDN_DEFINE_UNARY_OP(isnan) + VECTORIZEDN_DEFINE_UNARY_OP(abs) + VECTORIZEDN_DEFINE_UNARY_OP(sgn) + VECTORIZEDN_DEFINE_UNARY_OP(angle) + VECTORIZEDN_DEFINE_UNARY_OP(real) + VECTORIZEDN_DEFINE_UNARY_OP(imag) + VECTORIZEDN_DEFINE_UNARY_OP(conj) + VECTORIZEDN_DEFINE_UNARY_OP(acos) + VECTORIZEDN_DEFINE_UNARY_OP(acosh) + VECTORIZEDN_DEFINE_UNARY_OP(asin) + VECTORIZEDN_DEFINE_UNARY_OP(asinh) + VECTORIZEDN_DEFINE_UNARY_OP(atan) + VECTORIZEDN_DEFINE_UNARY_OP(atanh) + VECTORIZEDN_DEFINE_BINARY_OP(atan2) + VECTORIZEDN_DEFINE_BINARY_OP(copysign) + VECTORIZEDN_DEFINE_UNARY_OP(erf) + VECTORIZEDN_DEFINE_UNARY_OP(erfc) + VECTORIZEDN_DEFINE_UNARY_OP(erfinv) + VECTORIZEDN_DEFINE_UNARY_OP(exp) + VECTORIZEDN_DEFINE_UNARY_OP(exp2) + VECTORIZEDN_DEFINE_UNARY_OP(expm1) + VECTORIZEDN_DEFINE_UNARY_OP(exp_u20) + VECTORIZEDN_DEFINE_UNARY_OP(frac) + VECTORIZEDN_DEFINE_BINARY_OP(fmod) + VECTORIZEDN_DEFINE_UNARY_OP(log) + VECTORIZEDN_DEFINE_UNARY_OP(log10) + VECTORIZEDN_DEFINE_UNARY_OP(log1p) + VECTORIZEDN_DEFINE_UNARY_OP(log2) + VECTORIZEDN_DEFINE_UNARY_OP(ceil) + VECTORIZEDN_DEFINE_UNARY_OP(cos) + VECTORIZEDN_DEFINE_UNARY_OP(cosh) + VECTORIZEDN_DEFINE_UNARY_OP(floor) + VECTORIZEDN_DEFINE_BINARY_OP(hypot) + VECTORIZEDN_DEFINE_UNARY_OP(i0) + VECTORIZEDN_DEFINE_UNARY_OP(i0e) + VECTORIZEDN_DEFINE_UNARY_OP(digamma) + VECTORIZEDN_DEFINE_BINARY_OP(igamma) + VECTORIZEDN_DEFINE_BINARY_OP(igammac) + VECTORIZEDN_DEFINE_UNARY_OP(neg) + 
VECTORIZEDN_DEFINE_BINARY_OP(nextafter) + VECTORIZEDN_DEFINE_UNARY_OP(round) + VECTORIZEDN_DEFINE_UNARY_OP(sin) + VECTORIZEDN_DEFINE_UNARY_OP(sinh) + VECTORIZEDN_DEFINE_UNARY_OP(tan) + VECTORIZEDN_DEFINE_UNARY_OP(tanh) + VECTORIZEDN_DEFINE_UNARY_OP(trunc) + VECTORIZEDN_DEFINE_UNARY_OP(lgamma) + VECTORIZEDN_DEFINE_UNARY_OP(sqrt) + VECTORIZEDN_DEFINE_UNARY_OP(reciprocal) + VECTORIZEDN_DEFINE_UNARY_OP(rsqrt) + VECTORIZEDN_DEFINE_BINARY_OP(pow) + VECTORIZEDN_DEFINE_BINARY_OP(operator==) + VECTORIZEDN_DEFINE_BINARY_OP(operator!=) + VECTORIZEDN_DEFINE_BINARY_OP(operator>=) + VECTORIZEDN_DEFINE_BINARY_OP(operator<=) + VECTORIZEDN_DEFINE_BINARY_OP(operator>) + VECTORIZEDN_DEFINE_BINARY_OP(operator<) + VECTORIZEDN_DEFINE_BINARY_OP(eq) + VECTORIZEDN_DEFINE_BINARY_OP(ne) + VECTORIZEDN_DEFINE_BINARY_OP(gt) + VECTORIZEDN_DEFINE_BINARY_OP(ge) + VECTORIZEDN_DEFINE_BINARY_OP(lt) + VECTORIZEDN_DEFINE_BINARY_OP(le) + +#undef VECTORIZEDN_DEFINE_UNARY_OP +#undef VECTORIZEDN_DEFINE_BINARY_OP +}; + +#define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op(const VectorizedN& a) { \ + return a.unary_op([](const Vectorized& a) { return op(a); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op( \ + const VectorizedN& a, const VectorizedN& b) { \ + return a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \ + return op(a, b); \ + }); \ + } + +#define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op( \ + const VectorizedN& a, \ + const VectorizedN& b, \ + const VectorizedN& c) { \ + return a.ternary_op( \ + b, \ + c, \ + [](const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& c) { return op(a, b, c); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \ + template \ + inline VectorizedN& op( \ + VectorizedN& a, const VectorizedN& b) { \ + a = a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \ + return op(a, b); \ + }); \ + return a; \ + } + +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmadd) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmsub) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(clamp) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^) +VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~) + +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=) + +#undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL +#undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL +#undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL + +template +inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN acc_vec) { + Vectorized 
vec_result = acc_vec[0]; + for (int i = 1; i < N; i++) { + vec_result = vec_fun(vec_result, acc_vec[i]); + } + return vec_reduce_all(vec_fun, vec_result); +} + +template +std::ostream& operator<<(std::ostream& stream, const VectorizedN& vec_n) { + stream << "vec_n["; + for (int i = 0; i < N; ++i) { + if (i != 0) { + stream << ", "; + } + stream << vec_n[i]; + } + stream << ']'; + return stream; +} +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpu/vml.h b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vml.h new file mode 100644 index 0000000000000000000000000000000000000000..109a3d3e9c354dc6f8a49dafd9dbdf1a51a33001 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpu/vml.h @@ -0,0 +1,170 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// This header implements various unary operations using a MKL VML style +// interface. + +// It implements various functions with a simple interface +// For example it enables the user to call vsin(float* out, const float* in, +// size) This functions takes a pointer to a continuous output array of floats and +// a constant input array. It will then apply sin to each value in the input +// array and write the result into the output array. out and in may point to the +// same memory, i.e. this fully supports in-place operations. These functions +// also implement their own parallelization, so take precautions when calling +// these from threaded functions. + +// When MKL is available it will call into MKL's VML library similar to NumPy +// If MKL is not available it will use SLEEF. + +// This file might be compiled under AVX or AVX2 when called from e.g. +// UnaryOpsKernel.cpp + +#include +#include +#include +#include +#include + +#if AT_MKL_ENABLED() && !defined(__APPLE__) +#include +#endif + + +namespace at::vml { +inline namespace CPU_CAPABILITY { + +using namespace vec; + +template +inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) { + parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { + map( + [](const Vectorized& x) { + return Vectorized((scalar_t)(1)) / x.sqrt(); + }, + out + begin, + in + begin, + end - begin); + }); +} + +// NB: We ignore numerical errors by convention and leave them to the user + +#define IMPLEMENT_VML(op) \ + template \ + inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \ + using vec_t = Vectorized>; \ + vec::map([](vec_t x) { return x.op(); }, out, in, size); \ + } \ + +IMPLEMENT_VML(abs) +IMPLEMENT_VML(acos) +IMPLEMENT_VML(asin) +IMPLEMENT_VML(atan) +IMPLEMENT_VML(atanh) +IMPLEMENT_VML(ceil) +IMPLEMENT_VML(cos) +// IMPLEMENT_VML(cosh) +IMPLEMENT_VML(erf) +IMPLEMENT_VML(erfc) +IMPLEMENT_VML(erfinv) +IMPLEMENT_VML(exp) +IMPLEMENT_VML(expm1) +IMPLEMENT_VML(floor) +IMPLEMENT_VML(i0) +IMPLEMENT_VML(i0e) +IMPLEMENT_VML(digamma) +IMPLEMENT_VML(reciprocal) +IMPLEMENT_VML(log) +IMPLEMENT_VML(log10) +IMPLEMENT_VML(log1p) +IMPLEMENT_VML(log2) +IMPLEMENT_VML(neg) +IMPLEMENT_VML(sin) +// IMPLEMENT_VML(sinh) +IMPLEMENT_VML(sqrt) +IMPLEMENT_VML(round) +IMPLEMENT_VML(rsqrt) +IMPLEMENT_VML(tan) +IMPLEMENT_VML(tanh) +IMPLEMENT_VML(trunc) +IMPLEMENT_VML(lgamma) + + +#if AT_MKL_ENABLED() && !defined(__APPLE__) + +// NB: LP64 MKL is the most commonly used and thus we assume it here. That means +// we need to expect MKL_INT to be of type int, which implies int32_t or int64_t in most +// cases. 
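// Illustrative sketch (standalone): the calling convention the VML-style
// wrappers above expose, reduced to a scalar loop. A function such as vsin
// takes an output pointer, a const input pointer, and a length, applies the
// unary op to every element, and must tolerate out == in (in-place use).
// vsin_ref is a local stand-in, not the header's implementation.
#include <cmath>
#include <cstdint>
#include <iostream>

void vsin_ref(float* out, const float* in, std::int64_t size) {
  for (std::int64_t i = 0; i < size; i++) {
    out[i] = std::sin(in[i]);  // reading in[i] before writing out[i] keeps
  }                            // in-place calls (out == in) well defined
}

int main() {
  float buf[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  vsin_ref(buf, buf, 4);       // in-place, as the header's interface allows
  for (float v : buf) std::cout << v << ' ';
  std::cout << '\n';
}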
+static_assert( + std::is_same_v || std::is_same_v, + "MKL_INT is assumed to be int32_t or int64_t"); +#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \ + template <> \ + inline void v##op(type * out, const type * in, int64_t size) { \ + auto constexpr max_mkl_ind = std::numeric_limits::max(); \ + if (size <= static_cast(max_mkl_ind)) { \ + vm##mkltype##mklop( \ + size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ + } else { \ + int64_t ind = 0; \ + int64_t chunks = size / max_mkl_ind; \ + int64_t rest = size % max_mkl_ind; \ + for (; ind < chunks; ind++) { \ + vm##mkltype##mklop( \ + max_mkl_ind, \ + in + ind * max_mkl_ind, \ + out + ind * max_mkl_ind, \ + VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ + } \ + vm##mkltype##mklop( \ + rest, \ + in + ind * max_mkl_ind, \ + out + ind * max_mkl_ind, \ + VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ + } \ + } + +#define IMPLEMENT_VML_MKL(op, mklop) \ + IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \ + IMPLEMENT_VML_MKL_STUB(op, mklop, double, d) + +// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple +// NB: expm1 is disabled because on some configs it produces expm1(nan)=-1 +IMPLEMENT_VML_MKL(acos, Acos) +IMPLEMENT_VML_MKL(asin, Asin) +IMPLEMENT_VML_MKL(atan, Atan) +IMPLEMENT_VML_MKL(cos, Cos) +// IMPLEMENT_VML_MKL(cosh, Cosh) +IMPLEMENT_VML_MKL(erf, Erf) +IMPLEMENT_VML_MKL(erfc, Erfc) +IMPLEMENT_VML_MKL(erfinv, ErfInv) +IMPLEMENT_VML_MKL(exp, Exp) +// IMPLEMENT_VML_MKL(expm1, Expm1) +IMPLEMENT_VML_MKL(log, Ln) +IMPLEMENT_VML_MKL(log10, Log10) +IMPLEMENT_VML_MKL(sin, Sin) +// IMPLEMENT_VML_MKL(sinh, Sinh) +IMPLEMENT_VML_MKL(sqrt, Sqrt) +IMPLEMENT_VML_MKL(tan, Tan) +IMPLEMENT_VML_MKL(tanh, Tanh) +IMPLEMENT_VML_MKL(trunc, Trunc) + +// Not vectorized in MKL version tested +// IMPLEMENT_VML_MKL(abs, Abs) +// IMPLEMENT_VML_MKL(log1p, Log1p) + +#if INTEL_MKL_VERSION >= 20180406 +IMPLEMENT_VML_MKL(log2, Log2) +#endif + +#endif + +} // namespace +} // namespace at::vml diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h new file mode 100644 index 0000000000000000000000000000000000000000..3938aa341bb3943a9e42a3178d3233868b755101 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include +#include + +#include + +// Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b67b0905a09fd2a1bb17f7cc69863fd849ded1ff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh @@ -0,0 +1,47 @@ +#include + +#include + +namespace at::cuda { + +/** + Computes ceil(a / b) +*/ +template +__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) { + return (a + b - 1) / b; +} + +namespace { + +// Threads per block for our apply kernel +// FIXME: use occupancy calculator instead +constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512; +constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4; + +template +inline bool getApplyGrid(uint64_t totalElements, dim3& grid, c10::DeviceIndex curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) { + if (curDevice == -1) return false; + uint64_t numel_per_thread = static_cast(max_threads_per_block) * static_cast(step); + uint64_t numBlocks = 
ATenCeilDiv(totalElements, numel_per_thread); + uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0]; + if (numBlocks > maxGridX) + numBlocks = maxGridX; + grid = dim3(numBlocks); + return true; +} + +constexpr int getApplyBlocksPerSM() { + return AT_APPLY_BLOCKS_PER_SM; +} + +constexpr int getApplyBlockSize() { + return AT_APPLY_THREADS_PER_BLOCK; +} + +inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) { + return dim3(max_threads_per_block); +} + +} // anonymous namespace +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/AsmUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/AsmUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1daf0349042c77bf0627c61ecfa294a5b5c73a3c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/AsmUtils.cuh @@ -0,0 +1,149 @@ +#pragma once +#include + +// Collection of direct PTX functions + +namespace at::cuda { + +template +struct Bitfield {}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + unsigned int getBitfield(unsigned int val, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + return (val >> pos) & m; +#else + unsigned int ret; + asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); + return ret; +#endif + } + + static __device__ __host__ __forceinline__ + unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; +#else + unsigned int ret; + asm("bfi.b32 %0, %1, %2, %3, %4;" : + "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len)); + return ret; +#endif + } +}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + uint64_t getBitfield(uint64_t val, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + return (val >> pos) & m; +#else + uint64_t ret; + asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); + return ret; +#endif + } + + static __device__ __host__ __forceinline__ + uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; +#else + uint64_t ret; + asm("bfi.b64 %0, %1, %2, %3, %4;" : + "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len)); + return ret; +#endif + } +}; + +__device__ __forceinline__ int getLaneId() { +#if defined(USE_ROCM) + return __lane_id(); +#else + int laneId; + asm("mov.s32 %0, %%laneid;" : "=r"(laneId) ); + return laneId; +#endif +} + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskLt() { + const std::uint64_t m = (1ull << getLaneId()) - 1ull; + return m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskLt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); + return mask; +} +#endif + +#if defined (USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskLe() { + std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); + return m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskLe() { + unsigned mask; + asm("mov.u32 %0, 
%%lanemask_le;" : "=r"(mask)); + return mask; +} +#endif + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskGt() { + const std::uint64_t m = getLaneMaskLe(); + return m ? ~m : m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskGt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); + return mask; +} +#endif + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskGe() { + const std::uint64_t m = getLaneMaskLt(); + return ~m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskGe() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); + return mask; +} +#endif + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/Atomic.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/Atomic.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8b2436a3b246c405ec51a260f3d030d422490152 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/Atomic.cuh @@ -0,0 +1,525 @@ +#pragma once + +#include +#include +#include + +#include + +#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + +template +struct AtomicFPOp; + +template <> +struct AtomicFPOp { + template + inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) { + unsigned int * address_as_ui = + (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + at::Half hsum; + do { + assumed = old; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + hsum = func(hsum, val); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + return hsum; + } +}; + +template <> +struct AtomicFPOp { + template + inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) { + unsigned int * address_as_ui = + (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + at::BFloat16 bsum; + do { + assumed = old; + bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + bsum = func(bsum, val); + old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + bsum.x = (size_t)address & 2 ? 
(old >> 16) : (old & 0xffff); + return bsum.x; + } +}; + +template <> +struct AtomicFPOp { + template + inline __device__ double operator() (double * address, double val, const func_t& func) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull; + unsigned long long int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, func(val, assumed)); + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); + } +}; + +#define ATOMIC_INTEGER_IMPL(NAME) \ +template \ +struct Atomic##NAME##IntegerImpl; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 3; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = offset * 8; \ + uint32_t old_byte; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_byte = (old >> shift) & 0xff; \ + newval = static_cast(func(val, static_cast(old_byte))); \ + newval = (old & ~(0x000000ff << shift)) | (newval << shift); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 2; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + bool is_32_align = offset; \ + uint32_t old = *address_as_ui; \ + uint32_t old_bytes; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_bytes = is_32_align ? old >> 16 : old & 0xffff; \ + newval = static_cast(func(val, static_cast(old_bytes))); \ + newval = is_32_align ? 
(old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + uint32_t * address_as_ui = (uint32_t *) (address); \ + uint32_t old = *address_as_ui; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + unsigned long long * address_as_ui = (unsigned long long *) (address); \ + unsigned long long old = *address_as_ui; \ + unsigned long long newval; \ + unsigned long long assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; + + +# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \ +inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \ +Atomic##NAME##IntegerImpl()(address, \ + val, \ + [](DTYPE a, DTYPE b) { \ + return OP; \ + }); \ +} \ + +ATOMIC_INTEGER_IMPL(Add) +GPU_ATOMIC_INTEGER(Add, a || b, bool) + +// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64) +inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) { + AtomicAddIntegerImpl()(address, + val, + [](uint8_t a, uint8_t b) { + return a + b; + }); +} + +inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) { + AtomicAddIntegerImpl()(address, + val, + [](int8_t a, int8_t b) { + return a + b; + }); +} + +inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) { + AtomicAddIntegerImpl()(address, + val, + [](int16_t a, int16_t b) { + return a + b; + }); +} + +inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) { + return atomicAdd(address, val); +} + +inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) { +#if defined(USE_ROCM) + __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +#else + static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed"); + atomicAdd(reinterpret_cast(address), static_cast(val)); +#endif +} + +inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) { +#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + return AtomicFPOp()(address, val, + [](at::Half hsum, at::Half val) { + return hsum + val; + }); +#else + return atomicAdd(reinterpret_cast<__half*>(address), val); +#endif +} + +inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) { +#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) +return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum + val; + }); +#else + __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val)); + return *reinterpret_cast(&r); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600) +// from CUDA C Programmic Guide +inline __device__ double atomicAdd(double* address, double val) +#if defined(__clang__) && defined(__CUDA__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wgcc-compat" + __attribute__((enable_if(true, ""))) +#pragma GCC 
diagnostic pop +#endif +{ + + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(val + __longlong_as_double(assumed)); + }); +} +#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__)) + +/* Note [hip-clang differences to hcc] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * The upcoming hip-clang compiler for ROCm differs from hcc in a few details. + * It exports the __HIP__ macro, we can hence differentiate between hcc and + * hip-clang. In the below, hcc only received support for atomicAdd with double + * typing after work week 18312. hip-clang had support from the first version. + * In general, the code-visible differences between hip-clang and hcc will be + * minimal. + */ + +#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__ + // This needs to be defined for the host side pass + inline __device__ double atomicAdd(double *address, double val) { } +#endif +#endif + +inline __device__ double gpuAtomicAdd(double *address, double val) { + return atomicAdd(address, val); +} + +inline __device__ float gpuAtomicAdd(float *address, float val) { + return atomicAdd(address, val); +} + +template +inline __device__ void gpuAtomicAdd(c10::complex *address, c10::complex val) { + gpuAtomicAdd(&address->real_, val.real_); + gpuAtomicAdd(&address->imag_, val.imag_); +} + +/* Note [gpuAtomicAdd vs atomicAdd] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Some extensions such as torchvision call atomicAdd() + * directly and require non-library provided data type support. Only for these, we + * continue to provide atomicAdd overloads. + */ +inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { + return gpuAtomicAdd(address, val); +} + +inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) { + return gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(uint8_t *address, uint8_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(int8_t *address, int8_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(int16_t *address, int16_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(int64_t *address, int64_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(bool *address, bool val) { + gpuAtomicAdd(address, val); +} + +/* Note [explicitly non-returning atomics] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet(). + * Due to compiler limitations, callers must opt-in to guarantee the optimized instruction. + * This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd, + * therefore we need a new API 'gpuAtomicAddNoReturn'. 
+ */ +template +inline __device__ void gpuAtomicAddNoReturn(c10::complex *address, c10::complex val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); } + +/* Note [HIP unsafeAtomicAdd] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Use unsafeAtomicAdd instead of atomicAdd for fp32 and fp64. + * On HIP, atomicAdd is always correct but is a slow CAS loop. + * unsafeAtomicAdd will use HW instructions and is much faster, + * but the caller must guarantee the pointer is GPU memory. + * If the pointer is system memory, the result is a silent no-op. + * This guarantee is upheld by all PyTorch uses of unsafeAtomicAdd. + * AMD HIP atomic header file is named amd_hip_atomic.h and is + * under the LLVM compiler directory. + */ +#if defined(USE_ROCM) +inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { +#if defined(__gfx908__) + atomicAddNoRet(address, val); +#else + (void)unsafeAtomicAdd(address, val); +#endif +} +inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { (void)unsafeAtomicAdd(address, val); } +#else +inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); } +#endif + +// Atomic multiplication implementation. + +ATOMIC_INTEGER_IMPL(Mul) +GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int16_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int32_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int64_t) + +inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return bsum * val; + }); +} + +inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum * val; + }); +} + +inline __device__ double gpuAtomicMul(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(val * __longlong_as_double(assumed)); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. 
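+// Illustrative usage sketch (names below are hypothetical caller-side
+// identifiers): device code normally calls these helpers directly. A
+// scatter-add style kernel that never consumes the previous value would prefer
+// the non-returning form from Note [explicitly non-returning atomics], while a
+// kernel that needs the old value (or a product) uses gpuAtomicAdd/gpuAtomicMul.
+//
+//   __global__ void scatter_add_sketch(float* out, const float* grad,
+//                                      const int64_t* index, int64_t n) {
+//     int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
+//     if (i < n) {
+//       gpuAtomicAddNoReturn(out + index[i], grad[i]);  // old value not needed
+//     }
+//   }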
+inline __device__ float gpuAtomicMul (float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(val * + __int_as_float(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} + +// Atomic maximum implementation. + +template +__host__ __device__ T safe_max(T a, T b) { + #if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b)); + #else + T max = at::_isnan(b) ? b : std::max(a, b); + #endif + + return max; +} + +ATOMIC_INTEGER_IMPL(Max) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t) + +inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return safe_max(bsum, val); + }); +} + +inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return safe_max(bsum, val); + }); +} + +inline __device__ double gpuAtomicMax(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(safe_max(val, __longlong_as_double(assumed))); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. +inline __device__ float gpuAtomicMax(float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(safe_max(val, __int_as_float(assumed)))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} + +// Atomic minimum implementation. + +template +__host__ __device__ T safe_min(T a, T b) { + #if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b)); + #else + T min = at::_isnan(b) ? 
b : std::min(a, b); + #endif + + return min; +} + +ATOMIC_INTEGER_IMPL(Min) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t) + +inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return safe_min(bsum, val); + }); +} + +inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return safe_min(bsum, val); + }); +} + +inline __device__ double gpuAtomicMin(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(safe_min(val, __longlong_as_double(assumed))); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. +inline __device__ float gpuAtomicMin(float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(safe_min(val, __int_as_float(assumed)))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..eb26308c52dfc4b1c62b67b22a76d6a6a37c241c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh @@ -0,0 +1,537 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// +// This file contains pointwise operation functions and kernels that +// work on both contiguous and non-contiguous tensor arguments of +// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without +// copying or temporary storage. +// + +/* + NOTE [ CUDA_tensor_applyN helpers ] + + The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4) + functions apply a pointwise operator to N tensor(s). + + The calling convention is + + 1. The template arguments should be, sequentially, + - First N typename args specify the scalar types of each of the N tensors. + - (Optional) `int step` arg specifies the number of elements processed + together at the same time. + Default is 1. + - A usually omitted (i.e., inferred) typename arg specifies the type of the + function/functor applied on `N * step` values in each iteration of each + CUDA thread. + 2. The arguments should be, sequentially, + - N tensors + - op: a function/functor that processes `N * step` values at the same time. + - If `step == 1`, it must have signature + `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where + `scalar*_t`s are the first N typename template args, and the inputs + are the `N` values from the `N` tensors retrieved at a common index. 
+ - Otherwise, it must must have signature + void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times + scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times + ..., + scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times + Different from `step == 1` case, it processes `N * step` values taken + from `step` common indices. Moreover, the first input `n` represents the + number of valid indices (it will always have `0 < n <= step`). It will + almost always be `step`, but at the boundary we may not have full `step` + elements and `n` can be a lesser value. + + E.g., if `step == 4` and `N == 2`, `op` could be + + [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4, + scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) { + // Only process u1, ..., un and v1, ..., vn. + // So if `n == 3`, `u4` and `v4` need not to be considered. + } + + In both cases, the references can actually be const, but at least one of + them should be non-const in order to write the output. + - (Optional, but recommended) N TensorArgType args that specify for each + tensor whether `op` reads AND writes ] (i.e., TensorArgType::ReadWrite), + or only reads (i.e., TensorArgType::ReadOnly). + Default is TensorArgType::ReadWrite for first Tensor, and + TensorArgType::ReadOnly for the rest. + + E.g., + + to compute a = b^2 for a and b of same dtype, we can call + + CUDA_tensor_apply2( + a, b, + [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; } + ); + + to work on 2 values at the same time, we can call + + CUDA_tensor_apply2( + a, b, + [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2, + const scalar2 &b_val1, const scalar2 &b_val2) { + // call special vectorized op here, or just do elementwise and enjoy unrolling... + // if n == 1, only process a_val1 and b_val1 + } + ); +*/ + +namespace at::cuda { + +// TODO: combine with TensorArg? So far that's been for debugging, and this is functional... +enum class TensorArgType { ReadWrite, ReadOnly }; + +namespace { + +// Rearrange dimensions for pointwise operations so that strides are in +// decreasing order as much as possible, so that kernels have better memory +// access patterns. +// +// For example, consider a binary operation on two "transposed" 2-dim tensors: +// sizes: 256 512 +// aInfo->strides: 1 256 +// bInfo->strides: 1 256 +// +// Given this, each concurrent memory access inside kernelPointwiseApply2() is +// exactly 256 elements apart, resulting in poor performance. +// +// This function exchanges dimensions so that memory access is contiguous: +// sizes: 512 256 +// aInfo->strides: 256 1 +// bInfo->strides: 256 1 +// +// (Actually, it becomes even better because now collapseDims() can turn each +// input into one contiguous array.) +// +// In general, given M (<=4) TensorInfo's with N dimensions, we can view each +// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange +// strides[i] and [j] if +// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M) +// (exchanging them will benefit input #k), and +// (2) strides[i][k] <= strieds[j][k] for all k +// (exchanging them will not make any input worse). 
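+// A short worked illustration of the exchange rule above (using the numbers
+// from the example): viewing the per-dimension strides as M-tuples gives
+//   dim #0: (1, 1)      dim #1: (256, 256)
+// Condition (1) holds (1 < 256 for both inputs) and condition (2) holds
+// (1 <= 256 for both), so the two dimensions are exchanged and the stride-1
+// dimension becomes the innermost one, i.e. consecutive linear indices map to
+// consecutive memory locations in both tensors. If instead one tensor had
+// strides (1, 256) and the other (256, 1), the swap would help one input and
+// hurt the other, condition (2) would fail, and the dimensions are left alone.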
+template +inline void rearrangeDims(detail::TensorInfo* aInfo, + detail::TensorInfo* bInfo = nullptr, + detail::TensorInfo* cInfo = nullptr, + detail::TensorInfo* dInfo = nullptr) { + int numInfos = 1; + int dims = aInfo->dims; + IndexType *sizes[4] = { aInfo->sizes, }; + IndexType *strides[4] = { aInfo->strides, }; + + if (bInfo != nullptr) { + ++numInfos; + if (bInfo->dims != dims) return; + sizes[1] = bInfo->sizes; + strides[1] = bInfo->strides; + } + + if (cInfo != nullptr) { + ++numInfos; + if (cInfo->dims != dims) return; + sizes[2] = cInfo->sizes; + strides[2] = cInfo->strides; + } + + if (dInfo != nullptr) { + ++numInfos; + if (dInfo->dims != dims) return; + sizes[3] = dInfo->sizes; + strides[3] = dInfo->strides; + } + + // Bail out if sizes do not match: we are using "deprecated pointwise + // behavior" among tensors of different shapes but same number of elements. + for (int i = 1; i < numInfos; ++i) { + for (int j = 0; j < dims; ++j) { + if (sizes[i][j] != sizes[0][j]) return; + } + } + + for (int i = 0; i < dims - 1; ++i) { + // No need to consider dimensions of size 1. + if (sizes[0][i] == 1) continue; + + for (int j = i + 1; j < dims; ++j) { + if (sizes[0][j] == 1) continue; + + // Compare the relative sizes of strides between dim #i and dim #j. + bool hasIncreasingStrides = false; + bool hasDecreasingStrides = false; + + for (int k = 0; k < numInfos; k++) { + IndexType stride_i = strides[k][i]; + IndexType stride_j = strides[k][j]; + if (stride_i < stride_j) { + hasIncreasingStrides = true; + } else if (stride_i > stride_j) { + hasDecreasingStrides = true; + } + } + + if (hasIncreasingStrides && !hasDecreasingStrides) { + for (int k = 0; k < numInfos; k++) { + IndexType size = sizes[k][i]; + sizes[k][i] = sizes[k][j]; + sizes[k][j] = size; + + IndexType stride = strides[k][i]; + strides[k][i] = strides[k][j]; + strides[k][j] = stride; + } + } + } + } +} + +// The `remaining_steps` argument is used to support an Op that operates on +// multiple elements at the same time. Generally, the strategy of ApplyOpN is to +// 1. Initialize `remaining_steps = step`, where `step` is the template arg of +// CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the +// number of elements in bound for this call. It will almost always be equal to +// `step` except at boundaries. +// 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in +// bound), and recursively call `ApplyOpN` with `remaining_steps - 1`. +// 3. At `remaining_steps = 0`, +// if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`; +// if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tensor1_valstep, +// tensor2_val1, tensor2_val2, ..., tensor2_valstep, +// ... +// tensorN_val1, tensorN_val2, ..., tensorN_valstep);` +// +// See NOTE [ CUDA_tensor_applyN helpers ] above for what Op may look like. + +template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, int n, + IndexType linearIndex, Offsets... aOffsets) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = sizeof...(Offsets) < n ? + detail::IndexToOffset::get(linearIndex, a) : 0; + + ApplyOp1::apply( + a, op, n, linearIndex + 1, aOffsets..., aOffset + ); +} +}; + +// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). +// We don't need to pass in how many elements need to be processed in this case.
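+// For intuition (an illustrative expansion; template arguments abbreviated):
+// with step == 2 the primary template above peels off one offset per level of
+// recursion, roughly
+//   ApplyOp1<..., remaining_steps=2>::apply(a, op, n, i)
+//     -> ApplyOp1<..., 1>::apply(a, op, n, i + 1, aOffset0)
+//     -> ApplyOp1<..., 0>::apply(a, op, n, i + 2, aOffset0, aOffset1)
+// and the remaining_steps == 0 specialization below finally calls
+//   op(n, a.data[aOffset0], a.data[aOffset1]);
+// Offsets whose position is >= n are computed as 0 and are expected to be
+// ignored by `op` (see NOTE [ CUDA_tensor_applyN helpers ]).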
+template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, + int n, IndexType linearIndex, Offset offset) { + op(a.data[offset]); +} +}; + +template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, int n, + IndexType linearIndex, Offsets... offsets) { + op(n, a.data[offsets]...); +} +}; + +template +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) +C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) +#endif +__global__ void kernelPointwiseApply1(detail::TensorInfo a, + IndexType totalElements, const Op op) { + for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x * step) { + ApplyOp1::apply( + a, op, ::min(step, static_cast(totalElements - linearIndex)), linearIndex); + } +} + + +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int64_t n, IndexType linearIndex, + Offsets... aOffsets, Offsets... bOffsets) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = static_cast(sizeof...(Offsets)) < n ? + detail::IndexToOffset::get(linearIndex, a) : 0; + + // Convert `linearIndex` into an offset of `b` + const IndexType bOffset = static_cast(sizeof...(Offsets)) < n ? + detail::IndexToOffset::get(linearIndex, b) : 0; + + ApplyOp2::apply( + a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset + ); +} +}; + +// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). +// We don't need to pass in how many elements need to processed in this case. +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int /*n*/, IndexType /*linearIndex*/, + Offset aOffset, Offset bOffset) { + op(a.data[aOffset], b.data[bOffset]); +} +}; + +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int n, IndexType linearIndex, + Offsets... aOffsets, Offsets... 
bOffsets) { + op(n, a.data[aOffsets]..., b.data[bOffsets]...); +} +}; + +template +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) +C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) +#endif +__global__ void +kernelPointwiseApply2(detail::TensorInfo a, + detail::TensorInfo b, + IndexType totalElements, + const Op op) { + for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x * step) { + ApplyOp2::apply( + a, b, op, ::min(step, static_cast(totalElements - linearIndex)), + linearIndex); + } +} + +} // anonymous namespace + +template +inline bool CUDA_tensor_apply2(at::TensorBase a, + at::TensorBase b, + const Op op, + TensorArgType aType = TensorArgType::ReadWrite, + TensorArgType bType = TensorArgType::ReadOnly) { + TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(), + "CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got " + "tensors with type ", a.device().type(), " and ", b.device().type()); + int64_t totalElements = a.numel(); + + if (totalElements != b.numel()) { + return false; + } + + if (a.dim() > MAX_TENSORINFO_DIMS || + b.dim() > MAX_TENSORINFO_DIMS) { + return false; + } + + if (a.numel() == 0) { + // Empty tensor; do nothing + return true; + } + const dim3 block = getApplyBlock(max_threads_per_block); + + dim3 grid; + auto curDevice = current_device(); + if (curDevice == -1) return false; + if (!getApplyGrid(totalElements, grid, curDevice, max_threads_per_block)) { + return false; + } + + /* + Expands readable/writable tensors whose indices may be "overlapped." + This ensures that each element of the tensor is operated on once and only + once. + */ + TensorBase oldA; + TensorBase oldB; + + if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { + // Must perform in contiguous space + oldA = std::exchange(a, a.contiguous()); + } + if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) { + // Must perform in contiguous space + oldB = std::exchange(b, b.contiguous()); + } + + // It is possible that the tensor dimensions are able to be collapsed, + // and thus we can reduce the actual code complexity of the copy by + // exploiting this knowledge statically, since the div/mod is the + // most expensive part of the operation, more so than memory accesses. + // For instance, when copying a non-contiguous to a contiguous tensor + // (or vice versa), the contiguous tensor can be collapsed to one + // dimension, and the loop to translate the linear index to the array + // index can be similarly collapsed. That is what this unrolling is for. 
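+ // As a concrete illustration (one common case, not an exhaustive list): for
+ // two contiguous tensors, collapseDims() reduces both TensorInfos to a single
+ // dimension, the <1, 1> instantiation below is selected, and the offset
+ // computation reduces to roughly `offset = linearIndex * stride0` with
+ // stride0 == 1. A permuted input that cannot be collapsed falls through to
+ // the dynamic (-1) case, which performs a div/mod per remaining dimension for
+ // every element.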
+ +#define HANDLE_CASE(TYPE, A, B) \ + kernelPointwiseApply2 \ + <<>>( \ + aInfo, bInfo, static_cast(totalElements), op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#define HANDLE_B_CASE(TYPE, A, B) { \ + switch (B) { \ + case 1: \ + HANDLE_CASE(TYPE, A, 1); \ + break; \ + case 2: \ + HANDLE_CASE(TYPE, A, 2); \ + break; \ + default: \ + HANDLE_CASE(TYPE, A, -1); \ + break; \ + } \ +} + +#define HANDLE_A_CASE(TYPE, A, B) { \ + switch (A) { \ + case 1: \ + HANDLE_B_CASE(TYPE, 1, B); \ + break; \ + case 2: \ + HANDLE_B_CASE(TYPE, 2, B); \ + break; \ + default: \ + HANDLE_B_CASE(TYPE, -1, B); \ + break; \ + } \ +} + + if (detail::canUse32BitIndexMath(a) && + detail::canUse32BitIndexMath(b)) { + detail::TensorInfo aInfo = + detail::getTensorInfo(a); + + detail::TensorInfo bInfo = + detail::getTensorInfo(b); + rearrangeDims(&aInfo, &bInfo); + aInfo.collapseDims(); + bInfo.collapseDims(); + + HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims); + } else { + detail::TensorInfo aInfo = + detail::getTensorInfo(a); + + detail::TensorInfo bInfo = + detail::getTensorInfo(b); + rearrangeDims(&aInfo, &bInfo); + aInfo.collapseDims(); + bInfo.collapseDims(); + + /* + Only instantiates the all 1D special case and the fallback all nD case for + large (64-bit indexed) tensors to reduce compilation time. + */ + if (aInfo.dims == 1 && bInfo.dims == 1) { + HANDLE_CASE(uint64_t, 1, 1); + } else { + HANDLE_CASE(uint64_t, -1, -1); + } + } +#undef HANDLE_CASE +#undef HANDLE_B_CASE +#undef HANDLE_A_CASE + + if (oldA.defined()) { + at::native::copy_ignoring_overlaps(oldA, a); + } + + if (oldB.defined()) { + at::native::copy_ignoring_overlaps(oldB, b); + } + + return true; +} + +/* Provides default step = 1 to CUDA_tensor_apply2. */ +template +inline bool CUDA_tensor_apply2(const at::TensorBase &a, + const at::TensorBase &b, + const Op op, + TensorArgType aType = TensorArgType::ReadWrite, + TensorArgType bType = TensorArgType::ReadOnly) { + return CUDA_tensor_apply2(a, b, op, aType, bType); +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDABlas.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDABlas.h new file mode 100644 index 0000000000000000000000000000000000000000..f99c49dd965c84a202986610c8112fac1f18518f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDABlas.h @@ -0,0 +1,417 @@ +#pragma once +/* + Provides a subset of CUDA BLAS functions as templates: + + gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc) + + gemv(transa, m, n, alpha, a, lda, x, incx, beta, y, incy) + + dot(n, x, incx, y, incy, result) + + where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot). + The functions are available in at::cuda::blas namespace. 
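+
+ For illustration only (assuming the usual BLAS/cuBLAS column-major
+ convention and densely packed matrices), a single-precision product
+ C = A * B with A of size m x k, B of size k x n and C of size m x n
+ could be issued as
+
+ at::cuda::blas::gemm<float>(
+ 'n', 'n', m, n, k,
+ 1.0f, A, /*lda=*/m, B, /*ldb=*/k,
+ 0.0f, C, /*ldc=*/m);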
+ */ + +#include +#include + +namespace at::cuda::blas { + +// RAII guard that sets the CuBLAS pointer mode and restores it to +// its previous value when the guard is destroyed +class PointerModeGuard { +public: + PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) : + handle(handle) { + TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode)); + TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode)); + } + + ~PointerModeGuard() { + cublasSetPointerMode(handle, previous_mode); + } + +private: + cublasHandle_t handle; + cublasPointerMode_t previous_mode{}; +}; + +/* LEVEL 3 BLAS FUNCTIONS */ + +#define CUDABLAS_GEMM_ARGTYPES(Dtype) CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) + +#define CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \ + char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type alpha, \ + const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type beta,\ + C_Dtype *c, int64_t ldc + +#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc + +#define CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT \ + ((std::is_same::value || std::is_same::value) && std::is_same::value) + +template ::type* = nullptr> +inline void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented"); +} + +template ::type* = nullptr> +void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)); + +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); +template<> +void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +template +inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented"); +} + +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); +template<> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +enum GEMMAndBiasActivationEpilogue { + None, + RELU, + GELU, +}; + +// NOTE: GELU activation is not supported prior to CUDA 11.4 and will +// do nothing if passed in that case. 
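+// Illustrative call shape only (argument values are hypothetical and the
+// leading dimensions depend on how the caller lays out its buffers): a fused
+// y = relu(mat1 * mat2 + bias) could be routed through
+//   gemm_and_bias<at::Half>(
+//       /*transpose_mat1=*/false, /*transpose_mat2=*/false,
+//       m, n, k, /*alpha_val=*/1.0f,
+//       mat1_ptr, mat1_ld, mat2_ptr, mat2_ld,
+//       bias_ptr, result_ptr, result_ld,
+//       GEMMAndBiasActivationEpilogue::RELU);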
+template +bool gemm_and_bias( + bool transpose_mat1, + bool transpose_mat2, + int64_t m, + int64_t n, + int64_t k, + at::opmath_type alpha_val, + const Dtype* mat1_ptr, + int64_t mat1_ld, + const Dtype* mat2_ptr, + int64_t mat2_ld, + const Dtype* bias, + C_Dtype* result_ptr, + int64_t result_ld, + GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None); + +void int8_gemm( + bool transpose_mat1, + bool transpose_mat2, + int64_t m, + int64_t n, + int64_t k, + const int8_t* mat1_ptr, + int64_t mat1_ld, + const int8_t* mat2_ptr, + int64_t mat2_ld, + int32_t* result_ptr, + int64_t result_ld); + +void scaled_gemm( + char transa, + char transb, + int64_t m, + int64_t n, + int64_t k, + const void* mat1_ptr, + const void* mat1_scale_ptr, + int64_t mat1_ld, + ScalarType mat1_dtype, + ScalarType mat1_scale_dtype, + const void* mat2_ptr, + const void* mat2_scale_ptr, + int64_t mat2_ld, + ScalarType mat2_dtype, + ScalarType mat2_scale_dtype, + const void* bias_ptr, + ScalarType bias_dtype, + void* result_ptr, + const void* result_scale_ptr, + int64_t result_ld, + ScalarType result_dtype, + bool use_fast_accum, + bool use_rowwise); + +#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) + +#define CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \ + char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type alpha, \ + const Dtype *a, int64_t lda, int64_t stridea, \ + const Dtype *b, int64_t ldb, int64_t strideb, \ + at::opmath_type beta, C_Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches + +#define CUDABLAS_BGEMM_ARGS(Dtype) \ + transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches + +template ::type* = nullptr> +inline void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented"); +} + +template ::type* = nullptr> +void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)); + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +template<> +void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +template +inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented"); +} + +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)); +template <> +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +template<> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +#define CUDABLAS_TRSM_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \ + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \ + const Dtype 
*alpha, const Dtype *A, int lda, Dtype *B, int ldb + +template +inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented"); +} + +template <> +TORCH_CUDA_CU_API void trsm(CUDABLAS_TRSM_ARGTYPES(float)); +template <> +TORCH_CUDA_CU_API void trsm(CUDABLAS_TRSM_ARGTYPES(double)); +template <> +TORCH_CUDA_CU_API void trsm>(CUDABLAS_TRSM_ARGTYPES(c10::complex)); +template <> +TORCH_CUDA_CU_API void trsm>(CUDABLAS_TRSM_ARGTYPES(c10::complex)); + +#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \ + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \ + const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \ + int batchCount + +template +inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented"); +} + +template <> +TORCH_CUDA_CU_API void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)); +template <> +TORCH_CUDA_CU_API void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)); +template <> +TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex)); +template <> +TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex)); + +/* LEVEL 2 BLAS FUNCTIONS */ + +#define CUDABLAS_GEMV_ARGTYPES(Dtype) \ + char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \ + const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy + +template +inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented"); +} + +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(double)); +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(float)); +template <> +void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); +template <> +void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(at::Half)); +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); + +/* LEVEL 1 BLAS FUNCTIONS */ + +#define CUDABLAS_DOT_ARGTYPES(Dtype) \ + cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \ + int incy, Dtype *result + +template +inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented"); +} + +template <> +void dot(CUDABLAS_DOT_ARGTYPES(double)); +template <> +void dot(CUDABLAS_DOT_ARGTYPES(float)); +template <> +void dot(CUDABLAS_DOT_ARGTYPES(at::Half)); +template <> +void dot(CUDABLAS_DOT_ARGTYPES(at::BFloat16)); +template <> +void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); +template <> +void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); + +template +inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented"); +} + +template <> +void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); +template <> +void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); + +#define CUDABLAS_GETRS_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasOperation_t trans, \ + int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \ + Dtype** dB_array, int ldb, int* info_array, int batchsize + +#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \ + cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \ + Dtype **tau_array, int *info, int batchsize + +#define CUDABLAS_GETRF_ARGTYPES(Dtype) \ + int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize + +#define 
CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasOperation_t trans, \ + int m, int n, int nrhs, Dtype** dA_array, int ldda, \ + Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize + +// HIP on Windows does not support getrs, geqrf, getrf, gels +#if !(defined(USE_ROCM) && defined(_MSC_VER)) + +template +void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented"); +} +template<> +TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(float)); +template<> +TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(double)); +template<> +TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex)); +template<> +TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex)); + +template +void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented"); +} +template <> +TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)); +template <> +TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)); +template <> +TORCH_CUDA_CU_API void geqrfBatched>( + CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)); +template <> +TORCH_CUDA_CU_API void geqrfBatched>( + CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)); + +template +void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::getrfBatched: not implemented"); +} +template<> +TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float)); +template<> +TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(double)); +template<> +TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); +template<> +TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); + +template +void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::gelsBatched: not implemented"); +} +template<> +TORCH_CUDA_CU_API void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(double)); +template<> +TORCH_CUDA_CU_API void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(float)); +template<> +TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex)); +template<> +TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex)); + +#else // !(defined(USE_ROCM) && defined(_MSC_VER)) + +template +void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::getrsBatched: not supported for HIP on Windows"); +} + +template +void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::geqrfBatched: not supported for HIP on Windows"); +} + +template +void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not supported for HIP on Windows"); +} + +template +void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::gelsBatched: not supported for HIP on Windows"); +} + +#endif // !(defined(USE_ROCM) && defined(_MSC_VER)) + +} // namespace at::cuda::blas diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAContext.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAContext.h new file mode 100644 index 0000000000000000000000000000000000000000..b257e3f16b4adb5efde62dff92ed6f8fb9bc1a64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAContext.h @@ -0,0 +1,9 @@ 
+#pragma once + +#include + +// Preserved for BC, as many files depend on these includes +#include +#include +#include +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAContextLight.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAContextLight.h new file mode 100644 index 0000000000000000000000000000000000000000..ba489763e4ea6b6accd487e49576561552b57e26 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAContextLight.h @@ -0,0 +1,102 @@ +#pragma once +// Light-weight version of CUDAContext.h with fewer transitive includes + +#include +#include + +#include +#include +#include + +// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also +// added bf16 support +#include + +#ifdef CUDART_VERSION +#include +#endif + +#if defined(USE_CUDSS) +#include +#endif + +#if defined(USE_ROCM) +#include +#endif + +#include +#include + +namespace c10 { +struct Allocator; +} + +namespace at::cuda { + +/* +A common CUDA interface for ATen. + +This interface is distinct from CUDAHooks, which defines an interface that links +to both CPU-only and CUDA builds. That interface is intended for runtime +dispatch and should be used from files that are included in both CPU-only and +CUDA builds. + +CUDAContext, on the other hand, should be preferred by files only included in +CUDA builds. It is intended to expose CUDA functionality in a consistent +manner. + +This means there is some overlap between the CUDAContext and CUDAHooks, but +the choice of which to use is simple: use CUDAContext when in a CUDA-only file, +use CUDAHooks otherwise. + +Note that CUDAContext simply defines an interface with no associated class. +It is expected that the modules whose functions compose this interface will +manage their own state. There is only a single CUDA context/state. +*/ + +/** + * DEPRECATED: use device_count() instead + */ +inline int64_t getNumGPUs() { + return c10::cuda::device_count(); +} + +/** + * CUDA is available if we compiled with CUDA, and there are one or more + * devices. If we compiled with CUDA but there is a driver problem, etc., + * this function will report CUDA is not available (rather than raise an error.) 
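+ *
+ * A typical guard (illustrative):
+ *
+ *   if (at::cuda::is_available()) {
+ *     // safe to create streams, query device properties, etc.
+ *   }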
+ */ +inline bool is_available() { + return c10::cuda::device_count() > 0; +} + +TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties(); + +TORCH_CUDA_CPP_API int warp_size(); + +TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device); + +TORCH_CUDA_CPP_API bool canDeviceAccessPeer( + c10::DeviceIndex device, + c10::DeviceIndex peer_device); + +TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); + +/* Handles */ +TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); +TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); +TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); + +TORCH_CUDA_CPP_API void clearCublasWorkspaces(); +TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublas_handle_stream_to_workspace(); +TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize(); + +#if defined(CUDART_VERSION) || defined(USE_ROCM) +TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); +#endif + +#if defined(USE_CUDSS) +TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle(); +#endif + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDADataType.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDADataType.h new file mode 100644 index 0000000000000000000000000000000000000000..2c35b980ad0bfac737ed46940647bb866f923a6c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDADataType.h @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include +#include + +namespace at::cuda { + +template +cudaDataType getCudaDataType() { + static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType."); + return {}; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_16F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_32F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_64F; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8U; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32I; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16BF; +} + +inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) { + switch (scalar_type) { + case c10::ScalarType::Byte: + return CUDA_R_8U; + case c10::ScalarType::Char: + return CUDA_R_8I; + case c10::ScalarType::Int: + return CUDA_R_32I; + case c10::ScalarType::Half: + return CUDA_R_16F; + case c10::ScalarType::Float: + return CUDA_R_32F; + case c10::ScalarType::Double: + return CUDA_R_64F; + case c10::ScalarType::ComplexHalf: + return CUDA_C_16F; + case c10::ScalarType::ComplexFloat: + return CUDA_C_32F; + case c10::ScalarType::ComplexDouble: + return CUDA_C_64F; + case c10::ScalarType::Short: + return CUDA_R_16I; + case c10::ScalarType::Long: + return CUDA_R_64I; + case c10::ScalarType::BFloat16: + return CUDA_R_16BF; +#if !defined(USE_ROCM) || ROCM_VERSION >= 60300 + case c10::ScalarType::Float8_e4m3fn: + return CUDA_R_8F_E4M3; + case c10::ScalarType::Float8_e5m2: + return CUDA_R_8F_E5M2; +#endif +#if defined(USE_ROCM) + case 
c10::ScalarType::Float8_e4m3fnuz: + return HIP_R_8F_E4M3_FNUZ; + case c10::ScalarType::Float8_e5m2fnuz: + return HIP_R_8F_E5M2_FNUZ; +#endif +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) + case c10::ScalarType::Float4_e2m1fn_x2: + return CUDA_R_4F_E2M1; +#endif + default: + TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.") + } +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDADevice.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDADevice.h new file mode 100644 index 0000000000000000000000000000000000000000..5353a06ca6b11f607151a0b7c64762234b617c79 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDADevice.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +#include +#include + +namespace at::cuda { + +inline Device getDeviceFromPtr(void* ptr) { + cudaPointerAttributes attr{}; + + AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); + +#if !defined(USE_ROCM) + TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered, + "The specified pointer resides on host memory and is not registered with any CUDA device."); +#endif + + return {c10::DeviceType::CUDA, static_cast(attr.device)}; +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAEvent.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..17828ca499b69c8f18a7c640c996a44759d9919a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAEvent.h @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +/* +* `cudaEventExternal` is a torch-specific flag that is used to +* indicate that the CUDAEvent will be used only for synchronization +* with work outside of the cuda graph, rather than creation of +* cross-stream dependencies within a cuda graph. Resources: +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e +*/ +#define cudaEventExternal 0x08 + +namespace at::cuda { + +/* +* CUDAEvents are movable not copyable wrappers around CUDA's events. +* +* CUDAEvents are constructed lazily when first recorded unless it is +* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this +* device is acquired from the first recording stream. However, if reconstructed +* from a handle, the device should be explicitly specified; or if ipc_handle() is +* called before the event is ever recorded, it will use the current device. +* Later streams that record the event must match this device. 
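+*
+* A minimal lifecycle sketch (illustrative; `stream` is whatever CUDAStream the
+* caller already holds, and timing requires constructing the events with
+* cudaEventDefault instead of the default cudaEventDisableTiming flag):
+*
+*   at::cuda::CUDAEvent start(cudaEventDefault), stop(cudaEventDefault);
+*   start.record(stream);     // cudaEvent_t lazily created here, on stream's device
+*   ...                       // enqueue work on `stream`
+*   stop.record(stream);
+*   stop.synchronize();       // host-side wait; query() polls without blocking
+*   float ms = start.elapsed_time(stop);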
+*/ +struct TORCH_CUDA_CPP_API CUDAEvent { + // Constructors + // Default value for `flags` is specified below - it's cudaEventDisableTiming + CUDAEvent() noexcept = default; + CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} + + CUDAEvent( + DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) { + CUDAGuard guard(device_index_); + + AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); + is_created_ = true; + } + + // Note: event destruction done on creating device to avoid creating a + // CUDA context on other devices. + ~CUDAEvent() { + try { + if (is_created_) { + CUDAGuard guard(device_index_); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast(event_)); + } + AT_CUDA_CHECK(cudaEventDestroy(event_)); + } + } catch (...) { /* No throw */ } + } + + CUDAEvent(const CUDAEvent&) = delete; + CUDAEvent& operator=(const CUDAEvent&) = delete; + + CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); } + CUDAEvent& operator=(CUDAEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator cudaEvent_t() const { return event(); } + + // Less than operator (to allow use in sets) + friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { + return left.event_ < right.event_; + } + + std::optional device() const { + if (is_created_) { + return at::Device(at::kCUDA, device_index_); + } else { + return {}; + } + } + + bool isCreated() const { return is_created_; } + DeviceIndex device_index() const {return device_index_;} + cudaEvent_t event() const { return event_; } + + // Note: cudaEventQuery can be safely called from any device + bool query() const { + if (!is_created_) { + return true; + } + + cudaError_t err = cudaEventQuery(event_); + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; + } + + void record() { record(getCurrentCUDAStream()); } + + void recordOnce(const CUDAStream& stream) { + if (!was_recorded_) record(stream); + } + + // Note: cudaEventRecord must be called on the same device as the event. + void record(const CUDAStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_, + " does not match recording stream's device ", stream.device_index(), "."); + CUDAGuard guard(device_index_); + +#ifndef USE_ROCM + // it is an error to use cudaEventRecordExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventRecordExternal : cudaEventRecordDefault; + AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); +#else + AT_CUDA_CHECK(cudaEventRecord(event_, stream)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record(at::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + was_recorded_ = true; + } + + // Note: cudaStreamWaitEvent must be called on the same device as the stream. + // The event has no actual GPU resources associated with it. 
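+ // Illustrative cross-stream ordering sketch (`producer` and `consumer` are
+ // CUDAStreams owned by the caller; `ev` is a CUDAEvent): block() only
+ // enqueues a device-side wait and returns immediately on the host.
+ //
+ //   kernel_a<<<blocks, threads, 0, producer.stream()>>>(...);
+ //   ev.record(producer);
+ //   ev.block(consumer);   // work later submitted to consumer waits for kernel_a
+ //   kernel_b<<<blocks, threads, 0, consumer.stream()>>>(...);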
+ void block(const CUDAStream& stream) { + if (is_created_) { + CUDAGuard guard(stream.device_index()); +#ifndef USE_ROCM + // it is an error to use cudaEventWaitExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault; + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); +#else + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait(at::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + } + } + + // Note: cudaEventElapsedTime can be safely called from any device + float elapsed_time(const CUDAEvent& other) const { + TORCH_CHECK_VALUE( + !(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + + float time_ms = 0; + // We do not strictly have to set the device index to the same as our event, + // but if we don't and the current device is not initialized, it will + // create a new cuda context, which will consume a lot of memory. + CUDAGuard guard(device_index_); + // raise cudaErrorNotReady if either event is recorded but not yet completed + AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + // Note: cudaEventSynchronize can be safely called from any device + void synchronize() const { + if (is_created_) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast(event_)); + } + AT_CUDA_CHECK(cudaEventSynchronize(event_)); + } + } + + // Note: cudaIpcGetEventHandle must be called on the same device as the event + void ipc_handle(cudaIpcEventHandle_t * handle) { + if (!is_created_) { + // this CUDAEvent object was initially constructed from flags but event_ + // is not created yet. 
+ createEvent(getCurrentCUDAStream().device_index()); + } + CUDAGuard guard(device_index_); + AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); + } + +private: + unsigned int flags_ = cudaEventDisableTiming; + bool is_created_ = false; + bool was_recorded_ = false; + bool external_ = false; + DeviceIndex device_index_ = -1; + cudaEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + external_ = (flags_ & cudaEventExternal) != 0; +#ifdef USE_ROCM + TORCH_CHECK(!external_, "External events are disallowed in rocm"); +#endif + flags_ &= ~cudaEventExternal; + device_index_ = device_index; + CUDAGuard guard(device_index_); + AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast(event_)); + } + is_created_ = true; + } + + void moveHelper(CUDAEvent&& other) { + std::swap(flags_, other.flags_); + std::swap(is_created_, other.is_created_); + std::swap(was_recorded_, other.was_recorded_); + std::swap(device_index_, other.device_index_); + std::swap(event_, other.event_); + } +}; + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..3c0faefbe2d407ce931d98bd783ed77eabad1329 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h @@ -0,0 +1,180 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +namespace at { + +namespace cuda { +struct CUDAGraph; +} + +/** + * Note [CUDA Graph-safe RNG states] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * + * Strategy: + * ~~~~~~~~~ + * (It helps to look at + * cuda/detail/PhiloxCudaStateRaw.cuh and + * cuda/detail/UnpackRaw.cuh + * while you read this.) + * + * A CUDA graph containing multiple RNG ops behaves like a + * single giant kernel from the perspective of ops external + * to the graph. During graph capture, logic in CUDAGeneratorImpl + * records the total of all offset increments that occur in the + * graphed region, and records the final total as the offset for + * the entire graph. + * + * When the graph reruns, the logic that reruns it + * increments this device's CUDA generator's offset + * by that total. + * + * Meanwhile, within the graph, at capture time, instead of + * populating PhiloxCudaStates with the uint64_t offset pulled + * directly from the global state, PhiloxCudaState uses a pointer + * to a one-element stream-local int64_t device tensor + * holding an initial offset value, and a uint64_t holding an + * intra-graph offset. (The intra-graph offset starts from zero + * when capture begins.) In each consumer kernel, + * at::cuda::philox::unpack computes the offset to use for this kernel + * as intra-graph offset + *initial offset. + * + * When the graph reruns, the logic that reruns it first + * fill_s the initial offset tensor with this device's + * CUDA generator's current offset. + * + * The control flow above ensures graphed execution is bitwise + * identical to eager execution as long as RNG ops are enqueued + * from a single thread, even if RNG ops and graphs containing + * RNG ops are enqueued and run simultaneously on multiple streams. 
+ * + * Usage: + * ~~~~~~ + * PhiloxCudaState in this file, and unpack() in + * cuda/CUDAGraphsUtils.cuh allow non-divergent use of + * CUDAGeneratorImpl whether graph capture is underway or not. + * + * Each PhiloxCudaState instance should be used for one and only one + * consumer kernel. + * + * Example (see e.g. native/cuda/Dropout.cu): + * + * #include + * #include + * + * __global__ void kernel(..., PhiloxCudaState philox_args) { + * auto seeds = at::cuda::philox::unpack(philox_args); + * IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + * curandStatePhilox4_32_10_t state; + * curand_init(std::get<0>(seeds), // seed + * idx, // per-thread subsequence + * std::get<1>(seeds), // offset in subsequence + * &state); + * ... + * } + * + * host_caller(...) { + * PhiloxCudaState rng_engine_inputs; + * { + * // See Note [Acquire lock when using random generators] + * std::lock_guard lock(gen->mutex_); + * + * // gen could be HostState or DevState here! No divergent code needed! + * rng_engine_inputs = gen->philox_cuda_state(offset_increment); + * } + * kernel<<<...>>>(..., rng_engine_inputs); + * } + * + */ + +struct CUDAGeneratorState : public c10::intrusive_ptr_target { + uint64_t seed_; + uint64_t philox_offset_per_thread_; + uint32_t offset_intragraph_; + bool capturing_{}; + std::unordered_set registered_graphs_; + at::TensorBase seed_extragraph_{}; + at::TensorBase offset_extragraph_{}; + + CUDAGeneratorState( + uint64_t seed = default_rng_seed_val, + uint64_t philox_offset_per_thread = 0, + uint32_t offset_intragraph = 0) + : seed_(seed), + philox_offset_per_thread_(philox_offset_per_thread), + offset_intragraph_(offset_intragraph) {} + + void increase(uint64_t increment); + + void register_graph(cuda::CUDAGraph* graph); + void unregister_graph(cuda::CUDAGraph* graph); + + void capture_prologue(); + // capture_epilogue returns the wholegraph_increment + uint64_t capture_epilogue(); + void replay_prologue(uint64_t wholegraph_increment); + c10::intrusive_ptr clone(); +}; + +struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl { + // Constructors + CUDAGeneratorImpl(DeviceIndex device_index = -1); + CUDAGeneratorImpl( + DeviceIndex device_index, + c10::intrusive_ptr state_); + ~CUDAGeneratorImpl() override = default; + + // CUDAGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + void graphsafe_set_state( + const c10::intrusive_ptr& state) override; + c10::intrusive_ptr graphsafe_get_state() const override; + + void set_philox_offset_per_thread(uint64_t offset); + uint64_t philox_offset_per_thread() const; + + void register_graph(cuda::CUDAGraph* graph); + void unregister_graph(cuda::CUDAGraph* graph); + + // Generates a PhiloxCudaState with a specified increment, and increment + // current state + PhiloxCudaState philox_cuda_state(uint64_t increment); + + bool reset_rnn_state() { + return !no_reset_rnn_state_.test_and_set(); + } + + // Temporarily accommodates call sites that use philox_engine_inputs. + // Allows incremental refactor of call sites to use philox_cuda_state. 
+ std::pair philox_engine_inputs(uint64_t increment); + + static c10::DeviceType device_type(); + + private: + CUDAGeneratorImpl* clone_impl() const override; + + c10::intrusive_ptr state_; + std::atomic_flag no_reset_rnn_state_{}; +}; + +namespace cuda::detail { + +TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator( + DeviceIndex device_index = -1); +TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1); + +} // namespace cuda::detail +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGraph.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGraph.h new file mode 100644 index 0000000000000000000000000000000000000000..90e76cc934634672a8a48bf507923e5ad0f63a46 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGraph.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + +struct Generator; +struct CUDAGeneratorImpl; +struct CUDAGeneratorState; + +namespace cuda { + +// Standalone way to get a unique mempool id usable as a pool=... argument +// to CUDAGraph::capture_begin +TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle(); + +struct TORCH_CUDA_CPP_API CUDAGraph { + CUDAGraph(bool keep_graph=false); + ~CUDAGraph(); + + // See Note [Explicit Registration of Generators to the CUDA Graph] + void register_generator_state(c10::intrusive_ptr state); + void register_generator_state(const at::Generator& generator); + void capture_begin( + MempoolId_t pool = {0, 0}, + cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal); + void capture_end(); + void instantiate(); + void replay(); + void reset(); + MempoolId_t pool(); + void enable_debug_mode(); + void debug_dump(const std::string& debug_path); + cudaGraph_t raw_cuda_graph(); + + protected: + cudaGraph_t graph_ = nullptr; + cudaGraphExec_t graph_exec_ = nullptr; + + // internal states so reset() can do its best cleaning up + + // Set to true in capture_end if cudaStreamEndCapture succeeded + // Set back to false after instantiate() unless keep_graph=True or + // enable_debug_mode() was called on any CUDAGraph instance. + bool has_graph_ = false; + // Set to true in capture_end if cudaStreamEndCapture succeeded + bool capture_ended_ = false; + // Set to true in capture_end if cudaGraphInstantiate succeeded + bool has_graph_exec_ = false; + + // the ID assigned by cuda during graph capture, + // used to identify when a stream is participating in capture + CaptureId_t capture_id_ = -1; + + // uuid used to request a particular private mempool from CUDACachingAllocator. + // By default, this will be set to {id_, 0}. + // + // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_ + // will be set to the other graph's mempool_id_, and therefore share a mempool with the + // other graph. + // + // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(), + // it will share a mempool with any other captures that used "pool=handle". + // + // Sharing a mempool across graphs saves memory, and it's safe if you + // know you'll replay those graphs in the same order you captured them. + MempoolId_t mempool_id_; + + // Stream on which capture began + at::cuda::CUDAStream capture_stream_; + + // multiple generator states and their wholegraph_increments in this graph + // that are managed by the CUDA Graph + ska::flat_hash_map, uint64_t> + captured_generator_states_; + + // Device where capture occurred. 
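// Illustrative capture/replay sketch (not part of this header; the function name
// below is made up). It assumes the usual pattern of capturing on a side stream
// taken from the stream pool rather than the default stream:
//
//   #include <ATen/cuda/CUDAGraph.h>
//   #include <ATen/cuda/CUDAContext.h>
//   #include <c10/cuda/CUDAGuard.h>
//
//   void run_graphed_region_sketch() {
//     at::cuda::CUDAGraph graph;
//     at::cuda::CUDAStream side_stream = at::cuda::getStreamFromPool();
//     {
//       c10::cuda::CUDAStreamGuard guard(side_stream);
//       graph.capture_begin();   // or capture_begin(other_graph.pool()) to share a mempool
//       // ... enqueue the kernels to be captured on side_stream ...
//       graph.capture_end();
//     }
//     graph.replay();            // re-launches the captured work as one unit
//   }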
Right now, for simplicity, we require all ops + // in a capture to run on the same device, but this is a limitation of CUDAGraph, + // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device + // captures if needed. + // init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor + static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1; + c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE}; + + bool keep_graph_; +}; + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..fa262a6110799a777dd8febd6d9b812a775534aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten. +// This file adds utils used by aten only. + +namespace at::cuda { + +using CaptureId_t = c10::cuda::CaptureId_t; +using CaptureStatus = c10::cuda::CaptureStatus; + +// Use this version where you don't want to create a CUDA context if none exists. +inline CaptureStatus currentStreamCaptureStatus() { + // don't create a context if we don't have to + if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) { + return c10::cuda::currentStreamCaptureStatusMayInitCtx(); + } else { + return CaptureStatus::None; + } +} + +inline void assertNotCapturing(const std::string& attempt) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK(status == CaptureStatus::None, + attempt, + " during CUDA graph capture. If you need this call to be captured, " + "please file an issue. " + "Current cudaStreamCaptureStatus: ", + status); +} + +inline void errorIfCapturingCudnnBenchmark(const std::string& version_specific) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK(status == CaptureStatus::None, + "Current cudaStreamCaptureStatus: ", + status, + "\nCapturing ", + version_specific, + "is prohibited. Possible causes of this error:\n" + "1. No warmup iterations occurred before capture.\n" + "2. 
The convolutions you're trying to capture use dynamic shapes, " + "in which case capturing them is generally prohibited."); +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparse.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparse.h new file mode 100644 index 0000000000000000000000000000000000000000..983fafbf114c47d884604c9270705324130ab4d8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparse.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#if defined(USE_ROCM) +#include +#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch) +#endif + +// cuSparse Generic API added in CUDA 10.1 +// Windows support added in CUDA 11.0 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32))) +#define AT_USE_CUSPARSE_GENERIC_API() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_API() 0 +#endif + +// cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \ + (CUSPARSE_VERSION < 12000) +#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1 +#else +#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0 +#endif + +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \ + (CUSPARSE_VERSION >= 12000) +#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1 +#else +#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0 +#endif + +#if defined(USE_ROCM) +// hipSparse const API added in v2.4.0 +#if HIPSPARSE_VERSION >= 200400 +#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1 +#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_GENERIC_API() 1 +#else +#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1 +#define AT_USE_HIPSPARSE_GENERIC_API() 1 +#endif +#else // USE_ROCM +#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_GENERIC_API() 0 +#endif // USE_ROCM + +// cuSparse Generic API spsv function was added in CUDA 11.3.0 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500) +#define AT_USE_CUSPARSE_GENERIC_SPSV() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_SPSV() 0 +#endif + +// cuSparse Generic API spsm function was added in CUDA 11.3.1 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600) +#define AT_USE_CUSPARSE_GENERIC_SPSM() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_SPSM() 0 +#endif + +// cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400) +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400) +#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0 +#endif + +// BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0) +#if defined(CUDART_VERSION) || defined(USE_ROCM) +#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1 +#else +#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0 +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h new file mode 100644 index 0000000000000000000000000000000000000000..9b370f2a791271747ac47bc3770d91dd93047bbe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h @@ -0,0 +1,320 @@ +#pragma once + +/* + Provides a subset of cuSPARSE functions as templates: + 
+ csrgeam2(...) + + where scalar_t is double, float, c10::complex or c10::complex. + The functions are available in at::cuda::sparse namespace. +*/ + +#include +#include + +// NOLINTBEGIN(misc-misplaced-const) +namespace at::cuda::sparse { + +#define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, int nnzA, \ + const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \ + const int *csrSortedColIndA, const scalar_t *beta, \ + const cusparseMatDescr_t descrB, int nnzB, \ + const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \ + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \ + const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \ + const int *csrSortedColIndC, size_t *pBufferSizeInBytes + +template +inline void csrgeam2_bufferSizeExt( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void csrgeam2_bufferSizeExt( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float)); +template <> +void csrgeam2_bufferSizeExt( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double)); +template <> +void csrgeam2_bufferSizeExt>( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex)); +template <> +void csrgeam2_bufferSizeExt>( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex)); + +#define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \ + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \ + int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \ + const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \ + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \ + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace + +template +inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) { + TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz( + handle, + m, + n, + descrA, + nnzA, + csrSortedRowPtrA, + csrSortedColIndA, + descrB, + nnzB, + csrSortedRowPtrB, + csrSortedColIndB, + descrC, + csrSortedRowPtrC, + nnzTotalDevHostPtr, + workspace)); +} + +#define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, int nnzA, \ + const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \ + const int *csrSortedColIndA, const scalar_t *beta, \ + const cusparseMatDescr_t descrB, int nnzB, \ + const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \ + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \ + scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \ + void *pBuffer + +template +inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::csrgeam2: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(float)); +template <> +void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(double)); +template <> +void csrgeam2>( + CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex)); +template <> +void csrgeam2>( + CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \ + int kb, int nnzb, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ 
+ const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc + +template +inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrmm: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrmm(CUSPARSE_BSRMM_ARGTYPES(float)); +template <> +void bsrmm(CUSPARSE_BSRMM_ARGTYPES(double)); +template <> +void bsrmm>(CUSPARSE_BSRMM_ARGTYPES(c10::complex)); +template <> +void bsrmm>(CUSPARSE_BSRMM_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nb, int nnzb, \ + const scalar_t *alpha, const cusparseMatDescr_t descrA, \ + const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \ + int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y + +template +inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrmv: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrmv(CUSPARSE_BSRMV_ARGTYPES(float)); +template <> +void bsrmv(CUSPARSE_BSRMV_ARGTYPES(double)); +template <> +void bsrmv>(CUSPARSE_BSRMV_ARGTYPES(c10::complex)); +template <> +void bsrmv>(CUSPARSE_BSRMV_ARGTYPES(c10::complex)); + +#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() + +#define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nnzb, \ + const cusparseMatDescr_t descrA, scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsv2Info_t info, int *pBufferSizeInBytes + +template +inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float)); +template <> +void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double)); +template <> +void bsrsv2_bufferSize>( + CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex)); +template <> +void bsrsv2_bufferSize>( + CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nnzb, \ + const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer + +template +inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsv2_analysis: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float)); +template <> +void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double)); +template <> +void bsrsv2_analysis>( + CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex)); +template <> +void bsrsv2_analysis>( + CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \ + cusparseSolvePolicy_t policy, void *pBuffer + +template 
+inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsv2_solve: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float)); +template <> +void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double)); +template <> +void bsrsv2_solve>( + CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex)); +template <> +void bsrsv2_solve>( + CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \ + int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsm2Info_t info, int *pBufferSizeInBytes + +template +inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float)); +template <> +void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double)); +template <> +void bsrsm2_bufferSize>( + CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex)); +template <> +void bsrsm2_bufferSize>( + CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \ + int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer + +template +inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsm2_analysis: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float)); +template <> +void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double)); +template <> +void bsrsm2_analysis>( + CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex)); +template <> +void bsrsm2_analysis>( + CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \ + int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \ + const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \ + int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \ + scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer + +template +inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsm2_solve: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float)); +template <> +void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double)); +template <> +void bsrsm2_solve>( + CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex)); +template <> +void bsrsm2_solve>( + CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex)); + +#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE + +} // namespace at::cuda::sparse +// NOLINTEND(misc-misplaced-const) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h 
b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h new file mode 100644 index 0000000000000000000000000000000000000000..c35fe91e6135cf0d60520da9df5b536229156423 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h @@ -0,0 +1,288 @@ +#pragma once + +#include +#include +#include + +#include + +#if defined(USE_ROCM) +#include +#endif + +namespace at::cuda::sparse { + +template +struct CuSparseDescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDASPARSE_CHECK(destructor(x)); + } + } +}; + +template +class CuSparseDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; + +#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() +template +struct ConstCuSparseDescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDASPARSE_CHECK(destructor(x)); + } + } +}; + +template +class ConstCuSparseDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; +#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS + +#if defined(USE_ROCM) +using cusparseMatDescr = std::remove_pointer_t; +using cusparseDnMatDescr = std::remove_pointer_t; +using cusparseDnVecDescr = std::remove_pointer_t; +using cusparseSpMatDescr = std::remove_pointer_t; +using cusparseSpMatDescr = std::remove_pointer_t; +using cusparseSpGEMMDescr = std::remove_pointer_t; +#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() +using bsrsv2Info = std::remove_pointer_t; +using bsrsm2Info = std::remove_pointer_t; +#endif +#endif + +// NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced +// API for const descriptors +cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr); + +class TORCH_CUDA_CPP_API CuSparseMatDescriptor + : public CuSparseDescriptor { + public: + CuSparseMatDescriptor() { + cusparseMatDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } + + CuSparseMatDescriptor(bool upper, bool unit) { + cusparseFillMode_t fill_mode = + upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; + cusparseDiagType_t diag_type = + unit ? 
CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; + cusparseMatDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); + TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode)); + TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type)); + descriptor_.reset(raw_descriptor); + } +}; + +#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() + +class TORCH_CUDA_CPP_API CuSparseBsrsv2Info + : public CuSparseDescriptor { + public: + CuSparseBsrsv2Info() { + bsrsv2Info_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; + +class TORCH_CUDA_CPP_API CuSparseBsrsm2Info + : public CuSparseDescriptor { + public: + CuSparseBsrsm2Info() { + bsrsm2Info_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; + +#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE + +#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() + +cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type); + +#if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() +class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor + : public CuSparseDescriptor { + public: + explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1); +}; + +class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor + : public CuSparseDescriptor { + public: + explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1); + cusparseDnMatDescr* unsafe_mutable_descriptor() const { + return const_cast(descriptor()); + } + cusparseDnMatDescr* unsafe_mutable_descriptor() { + return const_cast(descriptor()); + } +}; + +class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor + : public CuSparseDescriptor { + public: + explicit CuSparseDnVecDescriptor(const Tensor& input); +}; + +class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor + : public CuSparseDescriptor {}; + +#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() + class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor + : public ConstCuSparseDescriptor< + cusparseDnMatDescr, + &cusparseDestroyDnMat> { + public: + explicit CuSparseDnMatDescriptor( + const Tensor& input, + int64_t batch_offset = -1); + }; + + class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor + : public ConstCuSparseDescriptor< + const cusparseDnMatDescr, + &destroyConstDnMat> { + public: + explicit CuSparseConstDnMatDescriptor( + const Tensor& input, + int64_t batch_offset = -1); + cusparseDnMatDescr* unsafe_mutable_descriptor() const { + return const_cast(descriptor()); + } + cusparseDnMatDescr* unsafe_mutable_descriptor() { + return const_cast(descriptor()); + } + }; + + class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor + : public ConstCuSparseDescriptor< + cusparseDnVecDescr, + &cusparseDestroyDnVec> { + public: + explicit CuSparseDnVecDescriptor(const Tensor& input); + }; + + class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor + : public ConstCuSparseDescriptor< + cusparseSpMatDescr, + &cusparseDestroySpMat> {}; +#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() + +class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor + : public CuSparseSpMatDescriptor { + public: + explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1); + + std::tuple get_size() { + int64_t rows = 0, cols = 0, nnz = 0; + TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize( + 
this->descriptor(), + &rows, + &cols, + &nnz)); + return std::make_tuple(rows, cols, nnz); + } + + void set_tensor(const Tensor& input) { + auto crow_indices = input.crow_indices(); + auto col_indices = input.col_indices(); + auto values = input.values(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous()); + TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers( + this->descriptor(), + crow_indices.data_ptr(), + col_indices.data_ptr(), + values.data_ptr())); + } + +#if AT_USE_CUSPARSE_GENERIC_SPSV() + void set_mat_fill_mode(bool upper) { + cusparseFillMode_t fill_mode = + upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; + TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( + this->descriptor(), + CUSPARSE_SPMAT_FILL_MODE, + &fill_mode, + sizeof(fill_mode))); + } + + void set_mat_diag_type(bool unit) { + cusparseDiagType_t diag_type = + unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; + TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( + this->descriptor(), + CUSPARSE_SPMAT_DIAG_TYPE, + &diag_type, + sizeof(diag_type))); + } +#endif +}; + +#if AT_USE_CUSPARSE_GENERIC_SPSV() +class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor + : public CuSparseDescriptor { + public: + CuSparseSpSVDescriptor() { + cusparseSpSVDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; +#endif + +#if AT_USE_CUSPARSE_GENERIC_SPSM() +class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor + : public CuSparseDescriptor { + public: + CuSparseSpSMDescriptor() { + cusparseSpSMDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; +#endif + +class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor + : public CuSparseDescriptor { + public: + CuSparseSpGEMMDescriptor() { + cusparseSpGEMMDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; + +#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() + +} // namespace at::cuda::sparse diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e3c8526a0004cde8198965f3aea34af25ac5c452 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace at { +template <> +inline __half* Tensor::data() const { + return reinterpret_cast<__half*>(data()); +} +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..f41fae69ea89d078d61ebb3f698d0e24904761a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CUDAUtils.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at::cuda { + +// Check if every tensor in a list of tensors matches the current +// device. 
+inline bool check_device(ArrayRef ts) { + if (ts.empty()) { + return true; + } + Device curDevice = Device(kCUDA, current_device()); + for (const Tensor& t : ts) { + if (t.device() != curDevice) return false; + } + return true; +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..3cccaf3a064b263a0c420120ea806bdcbb5fe690 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::cuda { + +// +// A caching allocator for CUDA host allocations (pinned memory). +// +// This provides a drop-in replacement for THCudaHostAllocator, which re-uses +// freed pinned (page-locked) memory allocations. This avoids device +// synchronizations due to cudaFreeHost calls. +// +// To ensure correct behavior, THCCachingHostAllocator_recordEvent must be +// called anytime a pointer from this allocator is used in a cudaMemcpyAsync +// call between host and device, and passed the corresponding context from the +// allocation. This is currently invoked by at::native::copy_kernel_cuda. +// +C10_DEPRECATED_MESSAGE( + "at::cuda::getCachingHostAllocator() is deprecated. Please use at::getHostAllocator(at::kCUDA) instead.") +inline TORCH_CUDA_CPP_API at::HostAllocator* getCachingHostAllocator() { + return at::getHostAllocator(at::kCUDA); +} + +// Records an event in the specified stream. The allocation corresponding to the +// input `ptr`/`ctx` will not be re-used until the event has occurred. +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_recordEvent(...) is deprecated. Please use at::getHostAllocator(at::kCUDA)->record_event(...) instead.") +inline TORCH_CUDA_CPP_API bool CachingHostAllocator_recordEvent( + void* ptr, + void* ctx, + c10::cuda::CUDAStream stream) { + return getHostAllocator(at::kCUDA)->record_event(ptr, ctx, stream.unwrap()); +} + +// Releases cached pinned memory allocations via cudaHostFree +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_emptyCache() is deprecated. Please use at::getHostAllocator(at::kCUDA)->empty_cache() instead.") +inline TORCH_CUDA_CPP_API void CachingHostAllocator_emptyCache() { + getHostAllocator(at::kCUDA)->empty_cache(); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::HostAlloc(...) is deprecated. Please use at::getHostAllocator(at::kCUDA)->allocate(...) instead.") +inline TORCH_CUDA_CPP_API at::DataPtr HostAlloc(size_t size) { + return getHostAllocator(at::kCUDA)->allocate(size); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_getStats() is deprecated. Please use at::getHostAllocator(at::kCUDA)->get_stats() instead.") +inline TORCH_CUDA_CPP_API at::HostStats CachingHostAllocator_getStats() { + return getHostAllocator(at::kCUDA)->get_stats(); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_resetAccumulatedStats() is deprecated. Please use at::getHostAllocator(at::kCUDA)->reset_accumulated_stats() instead.") +inline TORCH_CUDA_CPP_API void CachingHostAllocator_resetAccumulatedStats() { + getHostAllocator(at::kCUDA)->reset_accumulated_stats(); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_resetPeakStats() is deprecated. 
Please use at::getHostAllocator(at::kCUDA)->reset_peak_stats() instead.") +inline TORCH_CUDA_CPP_API void CachingHostAllocator_resetPeakStats() { + getHostAllocator(at::kCUDA)->reset_peak_stats(); +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3a258954db6306d16caf24906499faa7bc54aa77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include + +__device__ __forceinline__ unsigned int ACTIVE_MASK() +{ +#if !defined(USE_ROCM) + return __activemask(); +#else +// will be ignored anyway + return 0xffffffff; +#endif +} + +__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) { +#if !defined(USE_ROCM) + return __syncwarp(mask); +#endif +} + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate) +{ +return __ballot(predicate); +} +#else +__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __ballot_sync(mask, predicate); +#else + return __ballot(predicate); +#endif +} +#endif + +template +__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_sync(mask, value, srcLane, width); +#else + return __shfl(value, srcLane, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_up_sync(mask, value, delta, width); +#else + return __shfl_up(value, delta, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif +} + +#if defined(USE_ROCM) +template<> +__device__ __forceinline__ int64_t WARP_SHFL_DOWN(int64_t value, unsigned int delta, int width , unsigned int mask) +{ + //(HIP doesn't support int64_t). 
Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ + int2 a = *reinterpret_cast(&value); + a.x = __shfl_down(a.x, delta); + a.y = __shfl_down(a.y, delta); + return *reinterpret_cast(&a); +} +#endif + +template<> +__device__ __forceinline__ c10::Half WARP_SHFL_DOWN(c10::Half value, unsigned int delta, int width, unsigned int mask) +{ + return c10::Half(WARP_SHFL_DOWN(value.x, delta, width, mask), c10::Half::from_bits_t{}); +} + +template +__device__ __forceinline__ c10::complex WARP_SHFL_DOWN(c10::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return c10::complex( + __shfl_down_sync(mask, value.real_, delta, width), + __shfl_down_sync(mask, value.imag_, delta, width)); +#else + return c10::complex( + __shfl_down(value.real_, delta, width), + __shfl_down(value.imag_, delta, width)); +#endif +} + +/** + * For CC 3.5+, perform a load using __ldg + */ +template +__device__ __forceinline__ T doLdg(const T* p) { +#if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM) + return __ldg(p); +#else + return *p; +#endif +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/EmptyTensor.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..ba70c1f8f6a30ea40f4f71d93bb99b7b8e4d401e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/EmptyTensor.h @@ -0,0 +1,44 @@ +#pragma once +#include + +namespace at::detail { + +TORCH_CUDA_CPP_API TensorBase empty_cuda( + IntArrayRef size, + ScalarType dtype, + std::optional device_opt, + std::optional memory_format_opt); + +TORCH_CUDA_CPP_API TensorBase empty_cuda( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_CUDA_CPP_API TensorBase empty_cuda( + IntArrayRef size, + const TensorOptions &options); + +TORCH_CUDA_CPP_API TensorBase empty_strided_cuda( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + std::optional device_opt); + +TORCH_CUDA_CPP_API TensorBase empty_strided_cuda( + IntArrayRef size, + IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TORCH_CUDA_CPP_API TensorBase empty_strided_cuda( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions &options); + + +} // namespace at::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/Exceptions.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/Exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..a92337000965258b0ab670c2a37576bd2378d9bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/Exceptions.h @@ -0,0 +1,230 @@ +#pragma once + +#include +#include +#include + +#if !defined(USE_ROCM) +#include +#else +#include +#endif + +#if defined(USE_CUDSS) +#include +#endif + +#include +#include +#include + + +namespace c10 { + +class CuDNNError : public c10::Error { + using Error::Error; +}; + +} // namespace c10 + +#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \ + do { \ + auto error_object = EXPR; \ + if (!error_object.is_good()) { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN Frontend error: ", error_object.get_message()); \ + } \ + } while (0) \ + +#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__) + +// See Note [CHECK macro] +#define AT_CUDNN_CHECK(EXPR, ...) 
\ + do { \ + cudnnStatus_t status = EXPR; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + if (status == CUDNN_STATUS_NOT_SUPPORTED) { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", \ + cudnnGetErrorString(status), \ + ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \ + } else { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \ + } \ + } \ + } while (0) + +namespace at::cuda::blas { +C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error); +} // namespace at::cuda::blas + +#define TORCH_CUDABLAS_CHECK(EXPR) \ + do { \ + cublasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \ + "CUDA error: ", \ + at::cuda::blas::_cublasGetErrorEnum(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +const char *cusparseGetErrorString(cusparseStatus_t status); + +#define TORCH_CUDASPARSE_CHECK(EXPR) \ + do { \ + cusparseStatus_t __err = EXPR; \ + TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \ + "CUDA error: ", \ + cusparseGetErrorString(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +#if defined(USE_CUDSS) +namespace at::cuda::cudss { +C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error); +} // namespace at::cuda::solver + +#define TORCH_CUDSS_CHECK(EXPR) \ + do { \ + cudssStatus_t __err = EXPR; \ + if (__err == CUDSS_STATUS_EXECUTION_FAILED) { \ + TORCH_CHECK_LINALG( \ + false, \ + "cudss error: ", \ + at::cuda::cudss::cudssGetErrorMessage(__err), \ + ", when calling `" #EXPR "`", \ + ". This error may appear if the input matrix contains NaN. ");\ + } else { \ + TORCH_CHECK( \ + __err == CUDSS_STATUS_SUCCESS, \ + "cudss error: ", \ + at::cuda::cudss::cudssGetErrorMessage(__err), \ + ", when calling `" #EXPR "`. "); \ + } \ + } while (0) +#else +#define TORCH_CUDSS_CHECK(EXPR) EXPR +#endif + +namespace at::cuda::solver { +#if !defined(USE_ROCM) + +C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status); + +constexpr const char* _cusolver_backend_suggestion = \ + "If you keep seeing this error, you may use " \ + "`torch.backends.cuda.preferred_linalg_library()` to try " \ + "linear algebra operators with other supported backends. " \ + "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library"; + +// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue. +#define TORCH_CUSOLVER_CHECK(EXPR) \ + do { \ + cusolverStatus_t __err = EXPR; \ + if (__err == CUSOLVER_STATUS_INVALID_VALUE) { \ + TORCH_CHECK_LINALG( \ + false, \ + "cusolver error: ", \ + at::cuda::solver::cusolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`", \ + ". This error may appear if the input matrix contains NaN. ", \ + at::cuda::solver::_cusolver_backend_suggestion); \ + } else { \ + TORCH_CHECK( \ + __err == CUSOLVER_STATUS_SUCCESS, \ + "cusolver error: ", \ + at::cuda::solver::cusolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`. ", \ + at::cuda::solver::_cusolver_backend_suggestion); \ + } \ + } while (0) + +#else // defined(USE_ROCM) + +C10_EXPORT const char* hipsolverGetErrorMessage(hipsolverStatus_t status); + +constexpr const char* _hipsolver_backend_suggestion = \ + "If you keep seeing this error, you may use " \ + "`torch.backends.cuda.preferred_linalg_library()` to try " \ + "linear algebra operators with other supported backends. 
" \ + "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library"; + +#define TORCH_CUSOLVER_CHECK(EXPR) \ + do { \ + hipsolverStatus_t __err = EXPR; \ + if (__err == HIPSOLVER_STATUS_INVALID_VALUE) { \ + TORCH_CHECK_LINALG( \ + false, \ + "hipsolver error: ", \ + at::cuda::solver::hipsolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`", \ + ". This error may appear if the input matrix contains NaN. ", \ + at::cuda::solver::_hipsolver_backend_suggestion); \ + } else { \ + TORCH_CHECK( \ + __err == HIPSOLVER_STATUS_SUCCESS, \ + "hipsolver error: ", \ + at::cuda::solver::hipsolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`. ", \ + at::cuda::solver::_hipsolver_backend_suggestion); \ + } \ + } while (0) +#endif +} // namespace at::cuda::solver + +#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR) + +// For CUDA Driver API +// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. +#if !defined(USE_ROCM) + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + [[maybe_unused]] CUresult get_error_str_err = \ + at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + TORCH_CHECK(false, "CUDA driver error: unknown error"); \ + } else { \ + TORCH_CHECK(false, "CUDA driver error: ", err_str); \ + } \ + } \ + } while (0) + +#else + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + TORCH_CHECK(false, "CUDA driver error: ", static_cast(__err)); \ + } \ + } while (0) + +#endif + +// For CUDA NVRTC +// +// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE, +// incorrectly produces the error string "NVRTC unknown error." +// The following maps it correctly. +// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. +#define AT_CUDA_NVRTC_CHECK(EXPR) \ + do { \ + nvrtcResult __err = EXPR; \ + if (__err != NVRTC_SUCCESS) { \ + if (static_cast(__err) != 7) { \ + TORCH_CHECK(false, "CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \ + } else { \ + TORCH_CHECK(false, "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \ + } \ + } \ + } while (0) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/NumericLimits.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/NumericLimits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d02b41a8157f30aeb4e91fc865ed654598318351 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/NumericLimits.cuh @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include +#include + +// NumericLimits.cuh is a holder for numeric limits definitions of commonly used +// types. This header is very specific to ROCm HIP and may be removed in the future. +// This header is derived from the legacy THCNumerics.cuh. + +// The lower_bound and upper_bound constants are same as lowest and max for +// integral types, but are -inf and +inf for floating point types. They are +// useful in implementing min, max, etc. 
+ +namespace at { + +template +struct numeric_limits { +}; + +// WARNING: the following at::numeric_limits definitions are there only to support +// HIP compilation for the moment. Use std::numeric_limits if you are not +// compiling for ROCm. +// from @colesbury: "The functions on numeric_limits aren't marked with +// __device__ which is why they don't work with ROCm. CUDA allows them +// because they're constexpr." + +namespace { + // ROCm doesn't like INFINITY too. + constexpr double inf = INFINITY; +} + +template <> +struct numeric_limits { + static inline __host__ __device__ bool lowest() { return false; } + static inline __host__ __device__ bool max() { return true; } + static inline __host__ __device__ bool lower_bound() { return false; } + static inline __host__ __device__ bool upper_bound() { return true; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ uint8_t lowest() { return 0; } + static inline __host__ __device__ uint8_t max() { return UINT8_MAX; } + static inline __host__ __device__ uint8_t lower_bound() { return 0; } + static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int8_t lowest() { return INT8_MIN; } + static inline __host__ __device__ int8_t max() { return INT8_MAX; } + static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; } + static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } + static inline __host__ __device__ int16_t max() { return INT16_MAX; } + static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; } + static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } + static inline __host__ __device__ int32_t max() { return INT32_MAX; } + static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; } + static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } +}; + +template <> +struct numeric_limits { +#ifdef _MSC_VER + static inline __host__ __device__ int64_t lowest() { return _I64_MIN; } + static inline __host__ __device__ int64_t max() { return _I64_MAX; } + static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; } +#else + static inline __host__ __device__ int64_t lowest() { return INT64_MIN; } + static inline __host__ __device__ int64_t max() { return INT64_MAX; } + static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; } +#endif +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); } + static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); } + static inline 
__host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ float lowest() { return -FLT_MAX; } + static inline __host__ __device__ float max() { return FLT_MAX; } + static inline __host__ __device__ float lower_bound() { return -static_cast(inf); } + static inline __host__ __device__ float upper_bound() { return static_cast(inf); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ double lowest() { return -DBL_MAX; } + static inline __host__ __device__ double max() { return DBL_MAX; } + static inline __host__ __device__ double lower_bound() { return -inf; } + static inline __host__ __device__ double upper_bound() { return inf; } +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h new file mode 100644 index 0000000000000000000000000000000000000000..281b02053d919f90b22f648d7ba339d9c4babe9e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h @@ -0,0 +1,12 @@ +#include +#include +#include + +namespace at::cuda { +namespace detail { +void init_p2p_access_cache(int64_t num_devices); +} + +TORCH_CUDA_CPP_API bool get_p2p_access(c10::DeviceIndex source_dev, c10::DeviceIndex dest_dev); + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h new file mode 100644 index 0000000000000000000000000000000000000000..257ac6bbb896ab2883e7e85011ddee1426f53d15 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..091dd5e4402b9987edebed96d6d06c3baffa8272 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..4f539c6dab9960832cdbcc736270e0f81f27573f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace at::cuda { + +inline TORCH_CUDA_CPP_API at::HostAllocator* getPinnedMemoryAllocator() { + return at::getHostAllocator(at::kCUDA); +} +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/ScanUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ScanUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f81f560b4b523f8bc81423183cbbccca9c9d45e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ScanUtils.cuh @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include + +// Collection of in-kernel scan / prefix sum utilities + +namespace at::cuda { + 
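// Illustrative sketch (not part of this header) tying together the pinned-memory
// allocator and the CachingHostAllocator notes above with an asynchronous H2D copy;
// the function name and buffer sizes are made up:
//
//   #include <cstring>
//   #include <ATen/cuda/PinnedMemoryAllocator.h>
//   #include <ATen/cuda/CUDAContext.h>
//
//   void staged_h2d_copy_sketch(const float* src, float* dst_dev, size_t n) {
//     at::HostAllocator* host_alloc = at::cuda::getPinnedMemoryAllocator();
//     at::DataPtr staging = host_alloc->allocate(n * sizeof(float));   // pinned host memory
//     std::memcpy(staging.get(), src, n * sizeof(float));
//
//     auto stream = at::cuda::getCurrentCUDAStream();
//     // Pinned memory lets the copy run asynchronously with respect to the host.
//     C10_CUDA_CHECK(cudaMemcpyAsync(dst_dev, staging.get(), n * sizeof(float),
//                                    cudaMemcpyHostToDevice, stream));
//     // Mark the buffer as in use on this stream so the caching allocator does not
//     // recycle it before the copy has completed.
//     host_alloc->record_event(staging.get(), staging.get_context(), stream.unwrap());
//   }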
+// Inclusive prefix sum for binary vars using intra-warp voting + +// shared memory +template +__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) { + // Within-warp, we use warp voting. +#if defined (USE_ROCM) + unsigned long long int vote = WARP_BALLOT(in); + T index = __popcll(getLaneMaskLe() & vote); + T carry = __popcll(vote); +#else + T vote = WARP_BALLOT(in); + T index = __popc(getLaneMaskLe() & vote); + T carry = __popc(vote); +#endif + + int warp = threadIdx.x / C10_WARP_SIZE; + + // Per each warp, write out a value + if (getLaneId() == 0) { + smem[warp] = carry; + } + + __syncthreads(); + + // Sum across warps in one thread. This appears to be faster than a + // warp shuffle scan for CC 3.0+ + if (threadIdx.x == 0) { + int current = 0; + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { + T v = smem[i]; + smem[i] = binop(smem[i], current); + current = binop(current, v); + } + } + + __syncthreads(); + + // load the carry from the preceding warp + if (warp >= 1) { + index = binop(index, smem[warp - 1]); + } + + *out = index; + + if (KillWARDependency) { + __syncthreads(); + } +} + +// Exclusive prefix sum for binary vars using intra-warp voting + +// shared memory +template +__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) { + inclusiveBinaryPrefixScan(smem, in, out, binop); + + // Inclusive to exclusive + *out -= (T) in; + + // The outgoing carry for all threads is the last warp's sum + *carry = smem[at::ceil_div(blockDim.x, C10_WARP_SIZE) - 1]; + + if (KillWARDependency) { + __syncthreads(); + } +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/Sleep.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/Sleep.h new file mode 100644 index 0000000000000000000000000000000000000000..2ce5941af8390e17d15d060b5e1bfe94936969a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/Sleep.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace at::cuda { + +// enqueues a kernel that spins for the specified number of cycles +TORCH_CUDA_CU_API void sleep(int64_t cycles); + +// flushes instruction cache for ROCm; no-op for CUDA +TORCH_CUDA_CU_API void flush_icache(); + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/ThrustAllocator.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ThrustAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..e7f56bd455e5a71bef001908cd55f0e40a45f6ad --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/ThrustAllocator.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace at::cuda { + +/// Allocator for Thrust to re-route its internal device allocations +/// to the THC allocator +class ThrustAllocator { +public: + typedef char value_type; + + char* allocate(std::ptrdiff_t size) { + return static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(size)); + } + + void deallocate(char* p, size_t size) { + c10::cuda::CUDACachingAllocator::raw_delete(p); + } +}; + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub-RadixSortPairs.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub-RadixSortPairs.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bfe9c4a708ba10ad3d6c52813ea480914e2b77d1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub-RadixSortPairs.cuh @@ -0,0 +1,74 @@ +#pragma once + +#define TORCH_ASSERT_NO_OPERATORS 
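To make the warp-voting trick in inclusiveBinaryPrefixScan/exclusiveBinaryPrefixScan above concrete, here is a host-only emulation (editor's sketch) of the per-warp step for a single 32-lane warp: the ballot packs one predicate bit per lane, a popcount of the bits at or below a lane gives that lane's inclusive count, and subtracting the lane's own input gives the exclusive count. The lane values are arbitrary illustrations.

#include <bitset>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kWarpSize = 32;
  bool in[kWarpSize];
  for (int lane = 0; lane < kWarpSize; ++lane) in[lane] = (lane % 3 == 0);

  // Emulate WARP_BALLOT(in): one bit per lane.
  uint32_t vote = 0;
  for (int lane = 0; lane < kWarpSize; ++lane)
    if (in[lane]) vote |= (1u << lane);

  for (int lane = 0; lane < kWarpSize; ++lane) {
    // getLaneMaskLe(): bits 0..lane set.
    uint32_t lane_mask_le = (lane == 31) ? 0xffffffffu : ((1u << (lane + 1)) - 1u);
    int inclusive = static_cast<int>(std::bitset<32>(vote & lane_mask_le).count()); // __popc(mask & vote)
    int exclusive = inclusive - (in[lane] ? 1 : 0); // "*out -= (T) in" in the exclusive variant
    std::printf("lane %2d: in=%d inclusive=%d exclusive=%d\n", lane, in[lane], inclusive, exclusive);
  }
  // carry: the whole warp's sum, written to smem[warp] by lane 0 in the real kernel.
  std::printf("warp carry = %zu\n", std::bitset<32>(vote).count());
  return 0;
}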
+#include +#include + +namespace at::cuda::cub::detail { + +template +void radix_sort_pairs_impl( + const key_t* keys_in, + key_t* keys_out, + const OpaqueType* values_in, + OpaqueType* values_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(n * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } +} + +#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \ + template void radix_sort_pairs_impl( \ + const key_t* keys_in, \ + key_t* keys_out, \ + const OpaqueType* values_in, \ + OpaqueType* values_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); + +#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ + AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub.cuh new file mode 100644 index 0000000000000000000000000000000000000000..04c9c1c060d5611cb2173906ac40459b158c668d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub.cuh @@ -0,0 +1,625 @@ +#pragma once +#include + +#include +#include +#include +#include + +#include +#include + +#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE() + +#include + +#else + +// include cub in a safe manner, see: +// https://github.com/pytorch/pytorch/pull/55292 +#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings +#undef CUB_NS_PREFIX +#undef CUB_NS_QUALIFIER +#define CUB_NS_PREFIX namespace at_cuda_detail { +#define CUB_NS_POSTFIX } +#define CUB_NS_QUALIFIER ::at_cuda_detail::cub +#include +#undef CUB_NS_POSTFIX +#undef CUB_NS_PREFIX +#undef CUB_NS_QUALIFIER + +#endif + +#include +#include +#include + +// handle the temporary storage and 'twice' calls for cub API +#define CUB_WRAPPER(func, ...) 
do { \ + size_t temp_storage_bytes = 0; \ + AT_CUDA_CHECK(func(nullptr, temp_storage_bytes, __VA_ARGS__)); \ + auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \ + auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \ + AT_CUDA_CHECK(func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__));\ +} while (false) + +#ifdef USE_ROCM +#define NO_ROCM(x) +#define ROCM_HIPCUB(x) ::hipcub +#else +#define NO_ROCM(x) x +#define ROCM_HIPCUB(x) x +#endif + +#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) + +#if !defined(USE_ROCM) +namespace at_cuda_detail { +#endif + +// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 + +template <> +struct ROCM_HIPCUB(cub)::FpLimits +{ + static __host__ __device__ __forceinline__ c10::BFloat16 Max() { + unsigned short max_word = 0x7F7F; + return reinterpret_cast(max_word); + } + + static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() { + unsigned short lowest_word = 0xFF7F; + return reinterpret_cast(lowest_word); + } +}; + +template <> +struct ROCM_HIPCUB(cub)::NumericTraits: + ROCM_HIPCUB(cub)::BaseTraits {}; + +#if !defined(USE_ROCM) +} // namespace at_cuda_detail +#endif + +#endif + +#if !defined(USE_ROCM) +namespace at::native { +namespace cub = ::at_cuda_detail::cub; +} // namespace at::native +#endif + +namespace at::cuda::cub { + +namespace detail { + +template +struct cuda_type { + using type = T; +}; +template<> +struct cuda_type { + using type = __half; +}; + +#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16() + +template<> +struct cuda_type { + using type = __nv_bfloat16; +}; + +#elif defined(USE_ROCM) + +template<> +struct cuda_type { + using type = hip_bfloat16; +}; + +#endif + +} // namespace detail + +template +inline void segmented_sort_pairs( + const key_t *keys_in, key_t *keys_out, + const value_t *values_in, value_t *values_out, + int64_t num_elements, int64_t num_segments, + OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets, + bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8 +) { + TORCH_CHECK(num_elements <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + TORCH_CHECK(num_segments <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(num_elements * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_ *keys_in_ = reinterpret_cast(keys_in); + key_t_ *keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending, + keys_in_, keys_out_, values_in, values_out, + num_elements, num_segments, begin_offsets, end_offsets, + begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs, + keys_in_, keys_out_, values_in, values_out, + num_elements, num_segments, begin_offsets, end_offsets, + begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); + } +} + +#if CUB_SUPPORTS_UNIQUE_BY_KEY() +template +inline void unique_by_key( + KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, + ValuesOutputIteratorT values_out, + NumSelectedIteratorT num_selected, int64_t num_input_items) +{ + // TODO: use 
thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed. + using KeyT = typename std::iterator_traits::value_type; + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT)); + auto keys_out_ = static_cast(keys_out_owner.get()); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey, + keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream()); +} +#endif + +namespace impl { + +template +C10_LAUNCH_BOUNDS_1(1) +__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){ + // NOTE: out here not the final scan output, but an intermediate of the accumulation type. + using acc_t = typename std::iterator_traits::value_type; + *out = scan_op(static_cast(*a), static_cast(*b)); +} + +#if !CUB_SUPPORTS_FUTURE_VALUE() +template +struct chained_iterator { + using iterator_category = std::random_access_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = ValueT; + using pointer = ValueT*; + using reference = ValueT&; + + InputIteratorT iter; + ValueT *first; + difference_type offset = 0; + + __device__ ValueT operator[](difference_type i) { + i += offset; + if (i == 0) { + return *first; + } else { + return ValueT(iter[i - 1]); + } + } + __device__ chained_iterator operator+(difference_type i) { + return chained_iterator{iter, first, i}; + } + __device__ ValueT operator*() { + return (*this)[0]; + } +}; +#endif + +// even though cub is supposed to support tensors with int_max elements, in reality it doesn't, +// so split at int_max/2 +constexpr int max_cub_size = std::numeric_limits::max() / 2 + 1; // 2**30 +} + +// non synchronizing cub call +// even though cub is supposed to support tensors with int_max elements, in reality it doesn't, +// so split at int_max/2 +template +inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) { +#if defined(USE_ROCM) + //For ROCm, use hipCUB chained iterators + CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan, + input, + output, + scan_op, + num_items, + at::cuda::getCurrentCUDAStream()); + C10_HIP_KERNEL_LAUNCH_CHECK(); +#else + // non synchronizing cub call + // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, + // so split at int_max/2 + int size_cub = std::min(num_items, max_cub_size); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, + input, + output, + scan_op, + size_cub, + at::cuda::getCurrentCUDAStream()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + using input_t = typename std::iterator_traits::value_type; + for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) { + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr first_elem = allocator->allocate(sizeof(input_t)); + auto first_elem_ptr = reinterpret_cast(first_elem.get()); + + size_cub = std::min(num_items - i, max_cub_size); + impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + output + i - 1, + input + i, + first_elem_ptr, + scan_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#if !CUB_SUPPORTS_FUTURE_VALUE() + using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator; + using tuple = typename ArgIndexInputIterator::value_type; + auto input_iter_transform = [=] __device__ (const tuple &x)->input_t { + if (x.key == 0) { + return 
*first_elem_ptr; + } else { + return x.value; + } + }; + auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator( + ArgIndexInputIterator(input + i), input_iter_transform); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, + input_, + output + i, + scan_op, + size_cub, + at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, + input + i + 1, + output + i, + scan_op, + ::at_cuda_detail::cub::FutureValue(first_elem_ptr), + size_cub, + at::cuda::getCurrentCUDAStream()); +#endif + } +#endif +} + +# if defined(CUDA_VERSION) || defined(USE_ROCM) + +template +struct BlockPrefixCallbackOp +{ + public: + T running_total; + + __host__ __device__ BlockPrefixCallbackOp(T running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __host__ __device__ T operator()(T block_aggregate) + { + T old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +template +__global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem, int iters_per_cta) { + int64_t offset = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x; + int64_t remaining = nelem - offset; + if (remaining <= 0) { + return; + } + + d_in += offset; + d_out += offset; + + using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad; + + // Specialize BlockStore type for our thread block (uses warp-striped loads for coalescing, then transposes in shared + // memory to a blocked arrangement) + using BlockStoreT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockStore; + + // Specialize BlockScan type for our thread block + using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan; + using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce; + + + // Shared memory + __shared__ union TempStorage + { + typename BlockLoadT::TempStorage load; + typename BlockStoreT::TempStorage store; + typename BlockScanT::TempStorage scan; + typename BlockReduceT::TempStorage reduce; + } temp_storage; + + // load agg and reduce my starting value + T agg_data; + agg_data = threadIdx.x >= blockIdx.x ? T(0) : agg[threadIdx.x]; + // if there are fewer threads than previous values to be read, + // read another value + if (threadIdx.x + blockDim.x < blockIdx.x) { + agg_data += agg[threadIdx.x + blockDim.x]; + } + T aggregate = BlockReduceT(temp_storage.reduce).Sum(agg_data); + __syncthreads(); + BlockPrefixCallbackOp prefix_op(aggregate); + + + // Per-thread tile data + T data[ITEMS_PER_THREAD]; + + for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { + BlockLoadT(temp_storage.load).Load(d_in, data); + } else { + #pragma unroll + for (int j=0; j= BLOCK_THREADS * ITEMS_PER_THREAD) { + BlockStoreT(temp_storage.store).Store(d_out, data); + } else { + BlockStoreT(temp_storage.store).Store(d_out, data, remaining); + } + d_in += BLOCK_THREADS * ITEMS_PER_THREAD; + d_out += BLOCK_THREADS * ITEMS_PER_THREAD; + remaining -= BLOCK_THREADS * ITEMS_PER_THREAD; + if (remaining <= 0) return; + __syncthreads(); + } + +} + +template +struct TransformFunctor { + __device__ aggT operator()(T value) const { + if constexpr (!nonzero) { + return value; + } else { + return (value != T(0)) ? 
1 : 0; + } + } +}; + +template +__global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){ + int64_t offset = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x; + int64_t remaining = nelem - offset; + if (remaining <= 0) { + return; + } + d_in += offset; + + using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad; + using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce; + // Shared memory + __shared__ union TempStorage + { + typename BlockLoadT::TempStorage load; + typename BlockReduceT::TempStorage reduce; + } temp_storage; + aggT data[ITEMS_PER_THREAD]; + aggT agg_val = 0; + TransformFunctor transform_functor; + auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>(d_in, transform_functor); + for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { + BlockLoadT(temp_storage.load).Load(iter_in, data); + __syncthreads(); + agg_val += BlockReduceT(temp_storage.reduce).Sum(data); + + } else { + BlockLoadT(temp_storage.load).Load(iter_in, data, remaining, aggT(0)); + __syncthreads(); + agg_val += BlockReduceT(temp_storage.reduce).Sum(data); + } + iter_in += BLOCK_THREADS * ITEMS_PER_THREAD; + remaining -= BLOCK_THREADS * ITEMS_PER_THREAD; + if (remaining <= 0) { + // for nonzeros we need to write out last blocks + // accumulated value to be able to compute + // total number of nonzeros + if (nonzero && threadIdx.x == 0) { + agg[blockIdx.x] = agg_val; + } + return; + } + __syncthreads(); + + } + if (threadIdx.x == 0) { + agg[blockIdx.x] = agg_val; + } + +} + +template +struct NonZeroOp { + __host__ __device__ __forceinline__ int operator()(const T& a) const { + return (a != T(0)); + } +}; + +template +constexpr int block_threads(){ + if constexpr (size >=16) { + return 128; + } else if constexpr (size >=8) { + return 256; + } else { + return 512; + } +} + +template +inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * output, ScanOpT scan_op, int64_t num_items) { + static_assert(std::is_same_v>, ""); + constexpr int BLOCK_THREADS = block_threads(); + constexpr int ITEMS_PER_THREAD = 16; + auto grid_size = (num_items + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD); + const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + const int iters_per_cta = (grid_size + num_sms - 1)/num_sms; + grid_size = std::min(num_sms, grid_size); + // simple reduction in scan kernel handles at most 2 items per thread + TORCH_INTERNAL_ASSERT(2 * BLOCK_THREADS >= grid_size); + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + auto agg = allocator.allocate(grid_size * sizeof(scalar_t)); + calc_block_sums + <<>>( + input, (scalar_t*)agg.get(), num_items, iters_per_cta); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + final_scan_kernel + <<>>( + input, output, (scalar_t*)agg.get(), num_items, iters_per_cta); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#endif + +template +inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) { +#if defined(USE_ROCM) + //For ROCm, use hipCUB chained iterators + CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan, + input, + output, + scan_op, + init_value, + num_items, + at::cuda::getCurrentCUDAStream()); + C10_HIP_KERNEL_LAUNCH_CHECK(); +#else + // non synchronizing cub call + // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, + // so split at int_max/2 + int size_cub = 
std::min(num_items, max_cub_size); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, + input, + output, + scan_op, + init_value, + size_cub, + at::cuda::getCurrentCUDAStream()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) { + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT)); + auto first_elem_ptr = reinterpret_cast(first_elem.get()); + + size_cub = std::min(num_items - i, max_cub_size); + impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + output + i - 1, + input + i - 1, + first_elem_ptr, + scan_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#if !CUB_SUPPORTS_FUTURE_VALUE() + auto input_ = impl::chained_iterator{ + input + i, first_elem_ptr}; + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, + input_, + output + i, + scan_op, + size_cub, + at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, + input + i, + output + i, + scan_op, + ::at_cuda_detail::cub::FutureValue(first_elem_ptr), + size_cub, + at::cuda::getCurrentCUDAStream()); +#endif + } +#endif +} + +#if CUB_SUPPORTS_SCAN_BY_KEY() + +template +inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub InclusiveSumByKey does not support more than INT_MAX elements"); +#if !defined(USE_ROCM) + CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey, + keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey, + keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); +#endif +} + +template +inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub InclusiveSumByKey does not support more than INT_MAX elements"); +#if !defined(USE_ROCM) + CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey, + keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey, + keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); +#endif +} + +#endif + +template +void unique(InputIteratorT input, OutputIteratorT output, + NumSelectedIteratorT num_selected_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub unique does not support more than INT_MAX elements"); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique, + input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream()); +} + +template +void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out, + LengthOutputIteratorT length_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub run_length_encode does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode, + input, output, counts_out, length_out, num_items, + at::cuda::getCurrentCUDAStream()); +} + +template +void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) { + 
TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub reduce does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce, + input, output, num_items, op, init, + at::cuda::getCurrentCUDAStream()); + +} + +} // namespace at::cuda::cub diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub.h new file mode 100644 index 0000000000000000000000000000000000000000..077549e64c85f50b87c80e98b0ccbecb895ad010 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub.h @@ -0,0 +1,87 @@ +#pragma once +#include +#include +#include + +// NOTE: These templates are intentionally not defined in this header, +// which aviods re-compiling them for each translation unit. If you get +// a link error, you need to add an explicit instantiation for your +// types in cub.cu + +namespace at::cuda::cub { + +inline int get_num_bits(uint64_t max_key) { + int num_bits = 1; + while (max_key > 1) { + max_key >>= 1; + num_bits++; + } + return num_bits; +} + +namespace detail { + +// radix_sort_pairs doesn't interact with value_t other than to copy +// the data, so we can save template instantiations by reinterpreting +// it as an opaque type. +template struct alignas(N) OpaqueType { char data[N]; }; + +template +void radix_sort_pairs_impl( + const key_t *keys_in, key_t *keys_out, + const OpaqueType *values_in, OpaqueType *values_out, + int64_t n, bool descending, int64_t begin_bit, int64_t end_bit); + +} // namespace detail + +template +void radix_sort_pairs( + const key_t *keys_in, key_t *keys_out, + const value_t *values_in, value_t *values_out, + int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) { + static_assert(std::is_trivially_copyable_v || + AT_ROCM_ENABLED(), // ROCm incorrectly fails this check for vector types + "radix_sort_pairs value type must be trivially copyable"); + // Make value type opaque, so all inputs of a certain size use the same template instantiation + using opaque_t = detail::OpaqueType; + static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0, + "This size of value_t is not instantiated. 
Please instantiate it in cub.cu" + " and modify this check."); + static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned"); + detail::radix_sort_pairs_impl( + keys_in, keys_out, + reinterpret_cast(values_in), + reinterpret_cast(values_out), + n, descending, begin_bit, end_bit); +} + +template +void radix_sort_keys( + const key_t *keys_in, key_t *keys_out, + int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8); + +// NOTE: Intermediate sums will be truncated to input_t precision +template +void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n); + +template +void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) { + return inclusive_sum_truncating(input, output, n); +} + +// NOTE: Sums are done is common_type +template +void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n); + +template +void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) { + return exclusive_sum_in_common_type(input, output, n); +} + +void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n); +inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) { + return mask_exclusive_sum( + reinterpret_cast(mask), output_idx, n); +} + +} // namespace at::cuda::cub diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub_definitions.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub_definitions.cuh new file mode 100644 index 0000000000000000000000000000000000000000..fd0a157f208b9fc79759106e58d392ff3b840cee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/cub_definitions.cuh @@ -0,0 +1,53 @@ +#pragma once + +#if !defined(USE_ROCM) +#include // for CUDA_VERSION +#endif + +#if !defined(USE_ROCM) +#include +#else +#define CUB_VERSION 200001 +#endif + +// cub sort support for __nv_bfloat16 is added to cub 1.13 in: +// https://github.com/NVIDIA/cub/pull/306 +#if CUB_VERSION >= 101300 +#define CUB_SUPPORTS_NV_BFLOAT16() true +#else +#define CUB_SUPPORTS_NV_BFLOAT16() false +#endif + +// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: +// https://github.com/NVIDIA/cub/pull/326 +// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake +// starting from CUDA 11.5 +#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE) +#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true +#else +#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false +#endif + +// cub support for UniqueByKey is added to cub 1.16 in: +// https://github.com/NVIDIA/cub/pull/405 +#if CUB_VERSION >= 101600 +#define CUB_SUPPORTS_UNIQUE_BY_KEY() true +#else +#define CUB_SUPPORTS_UNIQUE_BY_KEY() false +#endif + +// cub support for scan by key is added to cub 1.15 +// in https://github.com/NVIDIA/cub/pull/376 +#if CUB_VERSION >= 101500 +#define CUB_SUPPORTS_SCAN_BY_KEY() 1 +#else +#define CUB_SUPPORTS_SCAN_BY_KEY() 0 +#endif + +// cub support for cub::FutureValue is added to cub 1.15 in: +// https://github.com/NVIDIA/cub/pull/305 +#if CUB_VERSION >= 101500 +#define CUB_SUPPORTS_FUTURE_VALUE() true +#else +#define CUB_SUPPORTS_FUTURE_VALUE() false +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..189ca8b5236cec25ec61f4144b1a8d90ccfa0516 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h @@ -0,0 
+1,66 @@ +#pragma once + +#include + +#include + +// TODO: No need to have this whole header, we can just put it all in +// the cpp file + +namespace at::cuda::detail { + +// Set the callback to initialize Magma, which is set by +// torch_cuda_cu. This indirection is required so magma_init is called +// in the same library where Magma will be used. +TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)()); + + +// The real implementation of CUDAHooksInterface +struct CUDAHooks : public at::CUDAHooksInterface { + CUDAHooks(at::CUDAHooksArgs) {} + void init() const override; + Device getDeviceFromPtr(void* data) const override; + bool isPinnedPtr(const void* data) const override; + const Generator& getDefaultGenerator( + DeviceIndex device_index = -1) const override; + Generator getNewGenerator( + DeviceIndex device_index = -1) const override; + bool hasCUDA() const override; + bool hasMAGMA() const override; + bool hasCuDNN() const override; + bool hasCuSOLVER() const override; + bool hasCuBLASLt() const override; + bool hasROCM() const override; + const at::cuda::NVRTC& nvrtc() const override; + DeviceIndex current_device() const override; + bool isBuilt() const override {return true;} + bool isAvailable() const override {return hasCUDA();} + bool hasPrimaryContext(DeviceIndex device_index) const override; + Allocator* getCUDADeviceAllocator() const override; + Allocator* getPinnedMemoryAllocator() const override; + bool compiledWithCuDNN() const override; + bool compiledWithMIOpen() const override; + bool supportsDilatedConvolutionWithCuDNN() const override; + bool supportsDepthwiseConvolutionWithCuDNN() const override; + bool supportsBFloat16ConvolutionWithCuDNNv8() const override; + bool hasCUDART() const override; + long versionCUDART() const override; + long versionCuDNN() const override; + long versionMIOpen() const override; + std::string showConfig() const override; + double batchnormMinEpsilonCuDNN() const override; + int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override; + void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override; + int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override; + void cuFFTClearPlanCache(DeviceIndex device_index) const override; + int getNumGPUs() const override; + DeviceIndex deviceCount() const override; + DeviceIndex getCurrentDevice() const override; + +#ifdef USE_ROCM + bool isGPUArch(const std::vector& archs, DeviceIndex device_index = -1) const override; +#endif + void deviceSynchronize(DeviceIndex device_index) const override; +}; + +} // at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h new file mode 100644 index 0000000000000000000000000000000000000000..5ff1919e9391a8166a37765a1846c88c389bb580 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h @@ -0,0 +1,151 @@ +// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states. +// These handles are tied to device, and these libraries requires/recommends not to +// share handles across host threads. +// +// These libraries recommend using one handle per host thread. We may not want to do +// this because threads are relatively light-weight, but creating and destroying +// handles is expensive (destroying the handle causes synchronizations). DataParallel, +// for example, creates new threads for each forward pass. 
+// +// This file implements a handle pool mechanism. The handle pool returns handles on +// demand as threads request them. If all existing handles in the pool are in use, +// it creates a new one. As threads terminate, they release handles back into the pool. +// In this way, the handle pool never creates more handles than the high-water mark of +// active threads, so it's efficient with DataParallel. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace at::cuda { namespace { + +template +struct DeviceThreadHandlePool : public std::enable_shared_from_this> { + + struct Handle { + Handle_t handle; + Handle(bool create = false) : handle(nullptr) + { + if(create) Create(&handle); + } + // std::vector.emplace() and push_back() may route through temporaries and call + // copy/move constructors along the way. If this is the case, we don't want + // the destructors of temporaries to call cudnnDestroy on the handle. + // We can achieve safety (for the narrow case of stashing within std::vectors) + // by making Handle moveable but not copyable, and transferring handle ownership + // to the latest constructed object. This is not a substitute for full-blown + // reference counting, but reference counting may be overkill here. + // Another alternative is to wrap the saved Handles in unique_ptrs, i.e., + // unordered_map>> created_handles; + Handle(const Handle& rhs) = delete; + // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom + Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); } + // operator= takes argument by value + Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; } + ~Handle() { + if(handle) Destroy(handle); + } + }; + + std::mutex mutex; + + // Handles are lazily created as different threads request them, + // but are never destroyed until the end of the process. + // The maximum number of handles this process will create for each device is equal + // to the high-water mark of the number of concurrently active threads that request + // handles for that device. + // When threads terminate, they release their handles back into the pool for reuse. + // Otherwise, new handles would be created every time new threads were spawned, + // resulting in poor performance for Python modules that repeatedly or frequently + // spawned new sets of threads (like DataParallel, which creates a new set of threads + // for each forward pass). + // + // To prevent potential deadlocks, we explicitly choose not to cap the number + // of handles that are created per device. + // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device, + // only 4 can make forward progress at any time. The other 4 will not release their + // handles until they exit, so the fifth cannot make progress until then. This is + // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an + // intermediate point (ie, before any of them have exited). We have no way to anticipate + // or enforce that user threads will not attempt such intermediate synchronization. + // The only way to ensure safety is to avoid imposing a cap on the number of handles. + std::unordered_map> created_handles; + std::unordered_map> available_handles; + + // PoolWindow lazily creates and caches the handles that a particular thread is using, + // so in the common case handle access doesn't incur either handle creation or a mutex lock. 
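Below is a small usage sketch (editor's addition) of how a pool like the one in this header is typically consumed: one process-wide pool, one thread_local PoolWindow per thread, and reserve(device) to fetch a handle. The handle type and the myCreate/myDestroy callbacks are hypothetical stand-ins (real clients pass the cuBLAS/cuDNN create/destroy functions), and the template argument list is spelled out here even though it is elided in the surrounding text.

#include <memory>

// Hypothetical handle type plus create/destroy callbacks, for this sketch only.
using myHandle_t = int*;
inline void myCreate(myHandle_t* h) { *h = new int(0); }
inline void myDestroy(myHandle_t h) { delete h; }

using MyPool = DeviceThreadHandlePool<myHandle_t, myCreate, myDestroy>;

inline myHandle_t getHandleForCurrentThread(int device) {
  // One shared pool for the process; one window per thread. When the thread
  // exits, the window's destructor releases its handles back into the pool.
  static auto pool = std::make_shared<MyPool>();
  thread_local std::unique_ptr<MyPool::PoolWindow> window(pool->newPoolWindow());
  return window->reserve(device);
}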
+ class PoolWindow + { + public: + PoolWindow(std::shared_ptr parent): weak_parent(std::move(parent)) {} + ~PoolWindow(){ release(); } + + Handle_t reserve(int device) + { + // If this thread already has a handle for this device, return it + if(my_handles.find(device) != my_handles.end()) + return my_handles[device]; + + // otherwise, either grab a handle from the pool if one is available, + // or if not, create a new one. + auto parent = weak_parent.lock(); + TORCH_CHECK(parent, "Cannot create handle during program termination"); + std::lock_guard guard(parent->mutex); + + if(parent->available_handles[device].size() > 0) + { + my_handles[device] = parent->available_handles[device].back(); + parent->available_handles[device].pop_back(); + } + else + { + // In local testing, I do observe that emplace_back sometimes routes through temporaries + // that incur move-constructor and destructor calls. See comments in Handle above. + parent->created_handles[device].emplace_back(true /*create*/); + my_handles[device] = parent->created_handles[device].back().handle; + } + + return my_handles[device]; + } + + private: + // Stores the per-device handles currently owned by this thread + std::unordered_map my_handles; + + std::weak_ptr weak_parent; + + // Called by the destructor. Releases this thread's handles back into the pool. + void release() { + if(my_handles.size() > 0) { + auto parent = weak_parent.lock(); + if (!parent) { + // If this thread exits after atexit handlers have completed, the + // cuda context itself may be invalid, so we must leak the handles. + return; + } + + std::lock_guard guard(parent->mutex); + for(auto d_h : my_handles) + parent->available_handles[d_h.first].push_back(d_h.second); + } + } + }; + + // Warning: + // If you want to change this function, be aware that this function will be called + // by multiple threads and there is no mutex guarding the call of this function, so + // make sure your implementation is thread-safe. + PoolWindow *newPoolWindow() { + // The returned pointer will be owned by a thread local variable + // so that different threads does not share the same PoolWindow. 
+ return new PoolWindow(this->shared_from_this()); + } +}; + +}} // namespace at::cuda::detail:: diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4462d035cebc8de4997b7d4d198039680960bb38 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +namespace at::cuda::detail { + +TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t); +using at::native::canUse32BitIndexMath; + +template +TensorInfo +getTensorInfo(const at::TensorBase &t) { + IndexType sz[MAX_TENSORINFO_DIMS]; + IndexType st[MAX_TENSORINFO_DIMS]; + + int dims = t.dim(); + for (int i = 0; i < dims; ++i) { + sz[i] = t.size(i); + st[i] = t.stride(i); + } + + scalar* data_ptr = nullptr; + + if constexpr (std::is_const_v) { + data_ptr = t.const_data_ptr(); + } else { + data_ptr = t.mutable_data_ptr(); + } + + return TensorInfo( + data_ptr, dims, sz, st); +} + +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7532aed5fee08a22c88135169634d206ab3c8982 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh @@ -0,0 +1,124 @@ +#pragma once + +#include +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#include +#endif + +namespace at::cuda::detail { + +// A utility class to implement integer division by multiplication, given a fixed +// divisor. +// +// WARNING: The fast divider algorithm is only implemented for unsigned int; +// otherwise we default to plain integer division. For unsigned int, +// we further assume that the dividend is at most INT32_MAX. Thus, +// IntDivider must NOT be used for general integer division. +// +// This reduced range is enough for our purpose, and it allows us to +// slightly simplify the computation. +// +// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1< 0), we can find a "magic number" m (2^N +// <= m < 2^(N+1)) and shift s such that: +// +// \floor(n / d) = \floor((m * n) / 2^(N+s)). +// +// Given such m and s, the integer division can be then implemented as: +// +// let m' = m - 2^N // 0 <= m' < 2^N +// +// fast_integer_division(n): +// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned +// // integer. Then take the higher N bits. +// t = (m' * n) >> N +// +// // Here we use the fact that n is less than 2^(N-1): otherwise the value +// // of (t + n) may not fit in an N-bit integer. +// return (t + n) >> s +// +// Finding such a magic number is surprisingly easy: +// +// s = \ceil(\log_2 d) +// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic. +// +// See also: +// - Division by Invariant Integers Using Multiplication, +// Torbjörn Granlund and Peter L. Montgomery, 1994. +// +// - http://www.hackersdelight.org/magic.htm +// +// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html + +// Result of div/mod operation stored together. +template +struct DivMod { + Value div, mod; + + C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { } +}; + +// Base case: we only have an implementation for uint32_t for now. For +// everything else, we use plain division. 
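A worked host-side check (editor's addition) of the magic-number recipe described above, using d = 7: compute s and m', then verify that ((m'·n) >> 32 + n) >> s reproduces n / d for a few dividends up to INT32_MAX. It mirrors the unsigned-int specialization that follows.

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const unsigned int d = 7;

  // s = ceil(log2 d)
  unsigned int shift = 0;
  while ((1u << shift) < d) ++shift;

  // m' = floor(2^32 * (2^s - d) / d) + 1, computed in 64-bit arithmetic.
  const uint64_t one = 1;
  const uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
  const unsigned int m1 = static_cast<unsigned int>(magic);
  assert(m1 == magic); // m' must fit in 32 bits
  std::printf("d=%u  shift=%u  m1=%u\n", d, shift, m1);

  for (unsigned int n : {0u, 1u, 6u, 7u, 8u, 12345u, 2147483647u}) { // dividend <= INT32_MAX
    const uint64_t t = (static_cast<uint64_t>(n) * m1) >> 32;        // high 32 bits of the product
    const unsigned int q = static_cast<unsigned int>((t + n) >> shift);
    assert(q == n / d);
    std::printf("n=%10u  fast=%9u  plain=%9u\n", n, q, n / d);
  }
  return 0;
}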
+template +struct IntDivider { + IntDivider() = default; + IntDivider(Value d) : divisor(d) { } + + C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; } + C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; } + C10_HOST_DEVICE inline DivMod divmod(Value n) const { + return DivMod(n / divisor, n % divisor); + } + + Value divisor; +}; + +// Implement fast integer division. +template <> +struct IntDivider { + static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int."); + + IntDivider() = default; + + IntDivider(unsigned int d) : divisor(d) { + assert(divisor >= 1 && divisor <= INT32_MAX); + + // TODO: gcc/clang has __builtin_clz() but it's not portable. + for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + m1 = magic; + assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits. + } + + C10_HOST_DEVICE inline unsigned int div(unsigned int n) const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and + // 'm1'. + unsigned int t = __umulhi(n, m1); + return (t + n) >> shift; +#else + // Using uint64_t so that the addition does not overflow. + uint64_t t = ((uint64_t) n * m1) >> 32; + return (t + n) >> shift; +#endif + } + + C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const { + return n - div(n) * divisor; + } + + C10_HOST_DEVICE inline DivMod divmod(unsigned int n) const { + unsigned int q = div(n); + return DivMod(q, n - q * divisor); + } + + unsigned int divisor; // d above. + unsigned int m1; // Magic number: m' above. + unsigned int shift; // Shift amounts. +}; + +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..02af7d1302b86c5e4bdf52f98eec5c800060291a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +namespace at::cuda::detail { + +// CUDA: grid stride looping +// +// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment. +// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final +// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be +// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no +// further iterations and the overflowed value in i=_i_n_d_e_x is not used. +#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \ + int64_t _i_n_d_e_x = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; \ + for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x) + +#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int) + + +// Use 1024 threads per block, which requires cuda sm_2x or above +constexpr int CUDA_NUM_THREADS = 1024; + +// CUDA: number of blocks for threads. 
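As a usage sketch (editor's addition), the grid-stride macro above and the GET_BLOCKS helper defined just below are typically combined as follows; scale_kernel is a hypothetical example kernel, not part of ATen, and the snippet assumes this header is included.

// Hypothetical elementwise kernel using the overflow-safe grid-stride loop.
__global__ void scale_kernel(float* x, float alpha, int64_t n) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    x[i] *= alpha;
  }
}

// Host-side launch: GET_BLOCKS rounds n up to whole blocks of CUDA_NUM_THREADS threads.
//   scale_kernel<<<at::cuda::detail::GET_BLOCKS(n), at::cuda::detail::CUDA_NUM_THREADS>>>(d_x, 2.0f, n);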
+inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) { + TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N); + constexpr int64_t max_int = std::numeric_limits::max(); + + // Round up division for positive number that cannot cause integer overflow + auto block_num = (N - 1) / max_threads_per_block + 1; + TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device"); + + return static_cast(block_num); +} + +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h new file mode 100644 index 0000000000000000000000000000000000000000..23821c88e964ea499df1479a0c369228ba854738 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h @@ -0,0 +1,11 @@ +#pragma once +#include +namespace at::cuda { +// Forward-declares at::cuda::NVRTC +struct NVRTC; + +namespace detail { +extern NVRTC lazyNVRTC; +} // namespace detail + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cbaa35aa8b967ce465b0121ea35bb513ed183cce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// If element_sizes is nullptr, then the strides will be in bytes, otherwise +// the strides will be in # of elements. +// Operands that share the same shape, but may have different strides. +// OffsetCalculator iterates the tensor in a column-major order + +#if defined(USE_ROCM) +constexpr int MAX_DIMS = 16; +#else +constexpr int MAX_DIMS = 25; +#endif + +template +struct OffsetCalculator { + // We allow having negative strides to implement some operations like torch.flip + using stride_t = std::conditional_t, + index_t>; + // The offset for each argument. Wrapper around fixed-size array. + // On CUDA, zero sized array is not allowed, so when we are handling nullary + // operators, we need to create a size 1 offset to avoid compiler failure. + // This size 1 offset is just a placeholder, and we will not use it. + using offset_type = std::array(NARGS, 1)>; + + // if element_sizes is nullptr, then the strides will be in bytes, otherwise + // the strides will be in # of elements. + OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) { + TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims"); + for (int i=0; i < dims; i++){ + sizes_[i] = at::cuda::detail::IntDivider(sizes[i]); + for (int arg = 0; arg < NARGS; arg++) { + int64_t element_size = (element_sizes == nullptr ? 
1LL : element_sizes[arg]); + strides_[i][arg] = strides[arg][i] / element_size; + } + } + } + + C10_HOST_DEVICE offset_type get(index_t linear_idx) const { + offset_type offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] = 0; + } + + #pragma unroll + for (int dim = 0; dim < MAX_DIMS; ++dim) { + if (dim == dims) { + break; + } + auto divmod = sizes_[dim].divmod(linear_idx); + linear_idx = divmod.div; + + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] += divmod.mod * strides_[dim][arg]; + } + + } + return offsets; + } + + int dims; + at::cuda::detail::IntDivider sizes_[MAX_DIMS]; + stride_t strides_[MAX_DIMS][std::max(NARGS, 1)]; +}; + +template +struct TrivialOffsetCalculator { + // The offset for each argument. Wrapper around fixed-size array. + // The offsets are in # of elements, not in bytes. + // On CUDA, zero sized array is not allowed, so when we are handling nullary + // operators, we need to create a size 1 offset to avoid compiler failure. + // This size 1 offset is just a placeholder, and we will not use it. + using offset_type = std::array(NARGS, 1)>; + + C10_HOST_DEVICE offset_type get(index_t linear_idx) const { + offset_type offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] = linear_idx; + } + return offsets; + } +}; + +// Make an OffsetCalculator with byte offsets +template +static OffsetCalculator make_offset_calculator(const at::TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(N <= iter.ntensors()); + std::array strides; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i).data(); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data()); +} + +// Make an OffsetCalculator with element offsets +template +static OffsetCalculator make_element_offset_calculator( + const at::TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(N <= iter.ntensors()); + std::array strides; + std::array element_sizes; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator( + iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data()); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7336fbccf663c0d2b514507b717dddb7584fefda --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh @@ -0,0 +1,43 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +// Stores RNG state values. Passed as a kernel argument. +// See Note [CUDA Graph-safe RNG states]. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +namespace at { + +struct PhiloxCudaState { + PhiloxCudaState() = default; + // Called if graph capture is not underway + PhiloxCudaState(uint64_t seed, + uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + // Called if graph capture is underway + PhiloxCudaState(int64_t* seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + // Public members, directly accessible by at::cuda::philox::unpack. 
+ // If we made them private with getters/setters, the getters/setters + // would have to be __device__, and we can't declare __device__ in ATen. + union Payload { + uint64_t val; + int64_t* ptr; + }; + + Payload seed_{}; + Payload offset_{}; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh new file mode 100644 index 0000000000000000000000000000000000000000..dec8f789c7358c3f487c1104007a0e7318829a0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include + +namespace at::cuda::detail { + +#define MAX_TENSORINFO_DIMS 25 + +// CUDA kernel argument that defines tensor layout +template +struct TensorInfo { + TensorInfo(); + TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]); + + // Set the size of the given dimension to 1, as if it were a + // reduction dim (allows you to calculate offsets of the reduction + // slice) + void reduceDim(int dim); + + // See note on [collapse dims]. + int collapseDims(const int excludeDim = -1); + + // Contiguous tensors of more than one dimension are collapsed down + // to one tensor + __host__ __device__ inline bool isContiguous() const { + return (dims == 1 && strides[0] == 1); + } + + T* data; + IndexType sizes[MAX_TENSORINFO_DIMS]; + IndexType strides[MAX_TENSORINFO_DIMS]; + int dims; +}; + +template +TensorInfo::TensorInfo() { + data = nullptr; + dims = 0; +} + +template +TensorInfo::TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]) { + data = p; + dims = dim; + TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions"); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } +} + +template +void +TensorInfo::reduceDim(int dim) { + TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1"); + sizes[dim] = 1; +} + +template +int +TensorInfo::collapseDims(const int excludeDim) { + auto result = at::collapse_dims(sizes, strides, dims, excludeDim); + dims = std::get<1>(result); + return std::get<0>(result); +} + +// Translate a linear index for the apply to a T* offset; +// specialized on `Dims` to reduce nvcc compilation time +template +struct IndexToOffset { + static __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + // Uses static dims + for (int i = Dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +// Uses dynamic (runtime) instead of static (compiletime) dims +template +struct IndexToOffset { + static inline __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + for (int i = info.dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh 
b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2548733f01042a5bebcd02690a77b6349e237ac5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh @@ -0,0 +1,34 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +namespace at::cuda::philox { + +// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether +// that instance was created with graph capture underway or not. +// See Note [CUDA Graph-safe RNG states]. +// +// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen. +// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable. +// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +__host__ __device__ __forceinline__ std::tuple +unpack(at::PhiloxCudaState arg) { + if (arg.captured_) { + // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". + // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. + // For most threads' reads it will hit in cache, so it shouldn't hurt performance. + return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +// Adapted from TE +// extract seed and offset from PhiloxCudaState +__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr); + +void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream); + +} // namespace at::cuda::philox diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/jiterator.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/jiterator.h new file mode 100644 index 0000000000000000000000000000000000000000..5e67b0f83c5d8a52cb1534bdbc7879138b53bdf9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/jiterator.h @@ -0,0 +1,40 @@ +#pragma once +#include + +#if AT_USE_JITERATOR() + +#include +#include +#include + +#include +#include + +namespace at::cuda { + +TORCH_CUDA_CPP_API c10::SmallVector CompileAndLaunchKernel( + const std::string& code_string, + const std::string& kernel_name, + const int num_outputs, + const c10::SmallVector& tensors, + const c10::SmallVector& extra_args, + bool return_by_ref); + +} // namespace at::cuda + +#else + +namespace at::cuda { + +TORCH_CUDA_CPP_API c10::SmallVector CompileAndLaunchKernel( + const std::string& code_string, + const std::string& kernel_name, + const int num_outputs, + const c10::SmallVector& tensors, + const c10::SmallVector& extra_args, + bool return_by_ref) { + TORCH_CHECK(false, "Jiterator is not supported"); + } +} // namespace at::cuda + +#endif // AT_USE_JITERATOR() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/jiterator_impl.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/jiterator_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..786218b2c5971e8b0cf7acac583ed4ee73607ab0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/jiterator_impl.h @@ -0,0 +1,250 @@ +#pragma once +#include + +#if AT_USE_JITERATOR() + 
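For the unpack() helper above, here is a host-side sketch (editor's addition, with made-up numbers) of the two cases it distinguishes: an eagerly constructed PhiloxCudaState returns its stored seed/offset values directly, while a graph-captured state dereferences generator-owned pointers and adds offset_intragraph_ so each RNG launch recorded in the graph can receive its own philox offset.

#include <cstdint>
#include <cstdio>

int main() {
  // Eager path: PhiloxCudaState(seed, offset) stores plain values.
  const uint64_t seed = 42, offset = 100;
  std::printf("eager:    seed=%llu offset=%llu\n",
              (unsigned long long)seed, (unsigned long long)offset);

  // Capture path: the state holds pointers into generator-owned memory plus a
  // per-launch intragraph offset; unpack() reads through the pointers and adds
  // the intragraph component (cast to avoid the narrowing warning noted above).
  const int64_t seed_cell = 42, offset_extragraph = 100;
  const uint32_t offset_intragraph = 16;
  std::printf("captured: seed=%llu offset=%llu\n",
              (unsigned long long)seed_cell,
              (unsigned long long)(offset_extragraph + offset_intragraph));
  return 0;
}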
+#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at::native { + + +#define AT_FOR_8_CASES(_) \ + _(1) \ + _(2) \ + _(3) \ + _(4) \ + _(5) \ + _(6) \ + _(7) \ + _(8) + +#define AT_FOR_8_CASES_WITH_COMMA(_) \ + _(1) , \ + _(2) , \ + _(3) , \ + _(4) , \ + _(5) , \ + _(6) , \ + _(7) , \ + _(8) + +c10::SmallVector get_extra_args_typenames(const c10::SmallVector& extra_args) { + c10::SmallVector args_typenames(extra_args.size()); + for (const auto i : c10::irange(extra_args.size())) { + args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type()); + } + return args_typenames; +} + +int can_vectorize_up_to(at::ScalarType type, char* pointer) { + switch(type) { +#define DEFINE_CASE(ctype, scalartype) \ + case ScalarType::scalartype : return memory::can_vectorize_up_to(pointer); + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) +#undef DEFINE_CASE + + default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); + } +} + +// jitted version of the above +// See Note [Jiterator], this relies on the assumptions enumerated there +int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) { + const at::ScalarType common_dtype = iter.common_dtype(); + const at::ScalarType result_dtype = common_dtype; + + // Deals with output + int result = can_vectorize_up_to(result_dtype, static_cast(iter.data_ptr(0))); + + // Incorporates input(s) + for (auto i = 1; i < iter.ntensors(); ++i) { + result = std::min(result, can_vectorize_up_to(common_dtype, static_cast(iter.data_ptr(i)))); + } + + return result; +} + +template +static std::unique_ptr> make_unique_offset_calculator( + const TensorIteratorBase& iter) { + // array size can not be 0, this happens when N == 0 + constexpr int array_size = std::max(N, 1); + TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? iter.ninputs() : iter.noutputs())); + + std::array strides; + int64_t element_sizes[array_size]; + for (int i = 0; i < N; i++) { + int index = IS_INPUT ? i + iter.noutputs() : i; + strides[i] = iter.strides(index).data(); + element_sizes[i] = iter.element_size(index); + } + return std::make_unique>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +struct OffsetCalculatorVariant { +#define DEFINE_CASE(index) std::unique_ptr> + using OffsetCalculatorTypes = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + OffsetCalculatorVariant(const TensorIteratorBase& iter) { + int num = IS_INPUT ? 
iter.ninputs() : iter.noutputs(); + + switch(num) { +#define DEFINE_CASE(index) \ + case index : v = make_unique_offset_calculator(iter); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + default: + TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(v.get()); }, v); + } + + private: + OffsetCalculatorTypes v{}; +}; + +struct ArrayVariant { +// works for up to 8 input + 8 outputs +#define DEFINE_CASE(index) std::array, std::array + using ArrayTypes = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + ArrayVariant(const TensorIteratorBase& iter) { + int ntensors = iter.ntensors(); + switch(ntensors) { +#define DEFINE_CASE(index) \ + case index: array = std::array{}; break; \ + case index+8: array = std::array{}; break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors); + } + + std::visit([&](auto& a) { + for (auto i = 0; i < ntensors; ++i) { + a[i] = (char*)iter.data_ptr(i); + } + }, array); + } + + void* data_ptr() { + return std::visit([](auto & a){ return static_cast(&a); }, array); + } + +private: + ArrayTypes array; +}; + +struct TrivialOffsetCalculatorVariant { +#define DEFINE_CASE(index) TrivialOffsetCalculator + using TrivialOffsetCalculatorTypes = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + TrivialOffsetCalculatorVariant(int num) { + switch(num) { +#define DEFINE_CASE(index) \ + case index: v = TrivialOffsetCalculator(); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(&v); }, v); + } + +private: + TrivialOffsetCalculatorTypes v{}; +}; + +struct LoadWithCastVariant { +#define DEFINE_CASE(index) std::unique_ptr> + using LoadWithCastPtr = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + LoadWithCastVariant(const TensorIteratorBase& iter) { + int arity = iter.ninputs(); + switch(arity) { +#define DEFINE_CASE(index) \ + case index: v = std::make_unique>(iter); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(v.get()); }, v); + } + +private: + LoadWithCastPtr v{}; +}; + +struct StoreWithCastVariant { +#define DEFINE_CASE(index) std::unique_ptr> + using StoreWithCastPtr = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + StoreWithCastVariant(const TensorIteratorBase& iter) { + int num = iter.noutputs(); + switch(num) { +#define DEFINE_CASE(index) \ + case index: v = std::make_unique>(iter); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(v.get()); }, v); + } + +private: + StoreWithCastPtr v{}; +}; + +} // namespace at::native + + +#endif // AT_USE_JITERATOR() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h new file mode 100644 index 
0000000000000000000000000000000000000000..ec2caa7b34b80eec75210988b7d6081e368f65bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace at::cuda { + +TORCH_CUDA_CPP_API const std::string &get_traits_string(); +TORCH_CUDA_CPP_API const std::string &get_cmath_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_body_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_half_body_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_math_string(); + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..e7aada989881300ccc23daba8980fcf68d19df01 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h @@ -0,0 +1,694 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#include + +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif +#include +#include + +namespace at::cuda::tunable { + +enum class BlasOp { + N = 0, + T = 1 +}; + +inline char BlasOpToString(BlasOp op) { + switch (op) { + case BlasOp::N: + return 'N'; + case BlasOp::T: + return 'T'; + } + TORCH_CHECK(false, "unrecognized BlasOp"); + return 'N'; +} + +template +inline const char* BLASTypeName(T v) { + return "unknown"; +} + +template <> +inline const char* BLASTypeName(float v) { + return "f32_r"; +} + +template <> +inline const char* BLASTypeName(double v) { + return "f64_r"; +} + +template <> +inline const char* BLASTypeName(BFloat16 v) { + return "bf16_r"; +} + +template <> +inline const char* BLASTypeName(Half v) { + return "f16_r"; +} + +//https://github.com/ROCm/hipBLASLt/blob/develop/library/src/include/auxiliary.hpp#L175 +template <> +inline const char* BLASTypeName(Float8_e4m3fn v) { + return "f8_r"; +} + +template <> +inline const char* BLASTypeName(Float8_e5m2 v) { + return "bf8_r"; +} + +template <> +inline const char* BLASTypeName(Float8_e4m3fnuz v) { + return "f8_fnuz_r"; +} + +template <> +inline const char* BLASTypeName(Float8_e5m2fnuz v) { + return "bf8_fnuz_r"; +} + +template <> +inline const char* BLASTypeName(c10::complex v) { + return "f64_r"; +} + +template <> +inline const char* BLASTypeName(c10::complex v) { + return "f32_r"; +} + +inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) { + std::string BLASType; + switch (scalar_type) { + case c10::ScalarType::Float:{ + BLASType = "f32_r"; + break; + } + case c10::ScalarType::Double:{ + BLASType = "f64_r"; + break; + } + case c10::ScalarType::BFloat16:{ + BLASType = "bf16_r"; + break; + } + case c10::ScalarType::Half: { + BLASType = "f16_r"; + break; + } + case c10::ScalarType::Float8_e4m3fn: { + BLASType = "f8_r"; + break; + } + case c10::ScalarType::Float8_e5m2: { + BLASType = "bf8_r"; + break; + } + case c10::ScalarType::Float8_e4m3fnuz: { + BLASType = "f8_fnuz_r"; + break; + } + case c10::ScalarType::Float8_e5m2fnuz: { + BLASType = "bf8_fnuz_r"; + 
break; + } + case c10::ScalarType::ComplexFloat:{ + BLASType = "f32_c"; + break; + } + case c10::ScalarType::ComplexDouble:{ + BLASType = "f64_c"; + break; + } + default: + BLASType = "unknown"; + } + return BLASType; +} + +// Similar to Compute Type in GemmRocblas.h +template +inline std::string ComputeTypeFor() { + return "Unknown ComputeType"; +} + +// This is a union of the compute types for +// ROCBLAS and hipBLASLt. +template <> +inline std::string ComputeTypeFor() { + if (!at::globalContext().allowTF32CuBLAS()) { + return "f32_r"; + } else { + return "xf32_r"; + } +} + +template <> +inline std::string ComputeTypeFor() { + return "f64_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor>() { + return "f32_c"; +} + +template <> +inline std::string ComputeTypeFor>() { + return "f64_c"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +// Convert opmath_type to string +template +inline std::string to_string_opmath(const at::opmath_type& value) { + if constexpr (std::is_same_v, c10::complex> || + std::is_same_v, c10::complex>) { + return fmt::format("({:.4f}, {:.4f})", value.real(), value.imag()); + } else { + return fmt::format("{:.4f}", value); + } +} + +// convert activation epilogue to string +inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivationEpilogue& value) { + switch (value) { + case at::cuda::blas::GEMMAndBiasActivationEpilogue::None: + return std::string("None"); + break; + case at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU: + return std::string("RELU"); + break; + case cuda::blas::GEMMAndBiasActivationEpilogue::GELU: + return std::string("GELU"); + break; + default: + return std::string("unknown"); + } +} + +namespace detail { + +static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) { + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA); + // comparison done as 1D tensor + at::Tensor ref = at::from_blob(c, {size}, options); + at::Tensor oth = at::from_blob(other_c, {size}, options); + at::Tensor ref_float = ref.to(at::kFloat); + at::Tensor oth_float = oth.to(at::kFloat); + std::vector atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + std::vector rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + double last_succeed_atol = 1; + double last_succeed_rtol = 1; + for (auto& atol : atols) { + for (auto& rtol : rtols) { + if (at::allclose(ref_float, oth_float, rtol, atol)) { + last_succeed_atol = atol; + last_succeed_rtol = rtol; + } + } + } + if (last_succeed_atol == 1) { + return false; + } + else { + TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); + } + + return true; +} + +} + +// Note on GetSizeA et al. +// Tensors can be dense or arbitrarily strided. We only need our copies to be large enough. +// Our copies must be at least as large as the m n k shapes dictate, but could be larger +// depending on the lda ldb ldc values. Similarly for the batched case. 
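The sizing note above says each scratch copy must cover the larger of the dense extent and the extent implied by the leading dimension. A rough standalone illustration of that rule for the A operand, assuming the same column-major convention as the params structs that follow (names here are hypothetical, not part of the header):

#include <algorithm>
#include <cstddef>

// For a non-transposed m x k matrix with leading dimension lda, the last column
// ends at element lda * k, which can exceed the dense m * k whenever lda > m.
inline std::size_t bytes_for_a(std::size_t elem_size, char transa,
                               std::size_t m, std::size_t k, std::size_t lda) {
  const std::size_t strided = lda * ((transa == 'n' || transa == 'N') ? k : m);
  const std::size_t dense = m * k;
  return elem_size * std::max(strided, dense);
}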
+ +template +struct GemmParams : OpParams { + GemmParams() = default; + + std::string BLASSignature() const override { + std::string alpha_str = to_string_opmath(alpha); + std::string beta_str = to_string_opmath(beta); + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, bias_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, alpha_str, beta_str, transa, transb, + BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), ComputeTypeFor(), ComputeTypeFor(), ComputeTypeFor()); + } + + std::string Signature() const override { + return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc); + } + + size_t GetSizeA() const { + size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t size_dense = m * k; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k); + size_t size_dense = k * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = ldc * n; + size_t size_dense = m * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmParams* DeepCopy(bool duplicate_inputs) const { + GemmParams* copy = new GemmParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? 
OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + at::opmath_type alpha; + const T* a{}; + int64_t lda{}; + const T* b{}; + int64_t ldb{}; + at::opmath_type beta; + T* c{}; + int64_t ldc{}; +private: + bool duplicate_inputs_{false}; +}; + +template +struct GemmAndBiasParams : OpParams { + std::string BLASSignature() const override { + std::string alpha_str = to_string_opmath(alpha); + std::string activation_str = to_string_epilogue(activation); + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "alpha: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, activation: %s, bias_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, alpha_str, transa, transb, + BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), activation_str, BLASTypeName(T{}), ComputeTypeFor(), ComputeTypeFor(), ComputeTypeFor()); + } + + std::string Signature() const override { + return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc); + } + + size_t GetSizeA() const { + size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t size_dense = m * k; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k); + size_t size_dense = k * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = ldc * n; + size_t size_dense = m * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const { + GemmAndBiasParams* copy = new GemmAndBiasParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmAndBiasParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? 
OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + at::opmath_type alpha{}; + const T* a{}; + int64_t lda{}; + const T* b{}; + int64_t ldb{}; + T* c{}; + int64_t ldc{}; + const T* bias{}; + at::cuda::blas::GEMMAndBiasActivationEpilogue activation{}; +private: + bool duplicate_inputs_{false}; +}; + +template +struct GemmStridedBatchedParams : OpParams { + std::string BLASSignature() const override { + std::string alpha_str = to_string_opmath(alpha); + std::string beta_str = to_string_opmath(beta); + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: %ld, stride_b: %ld, stride_c: %ld, stride_d: %ld, " + "alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: %ld, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, stride_a, stride_b, stride_c, stride_c, alpha_str, beta_str, transa, transb, batch, + BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(C_Dtype{}), BLASTypeName(T{}), ComputeTypeFor(), ComputeTypeFor()); + } + + std::string Signature() const override { + return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, batch, lda, ldb, ldc); + } + + size_t GetSizeA() const { + size_t size_stride = stride_a * batch; + size_t size_dense = m * k * batch; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = stride_b * batch; + size_t size_dense = k * n * batch; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = stride_c * batch; + size_t size_dense = m * n * batch; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const { + GemmStridedBatchedParams* copy = new GemmStridedBatchedParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + // NOLINTNEXTLINE(*const-cast*) + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + // NOLINTNEXTLINE(*const-cast*) + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmStridedBatchedParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? 
OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + at::opmath_type alpha{}; + const T* a{}; + int64_t lda{}; + int64_t stride_a{}; + const T* b{}; + int64_t ldb{}; + int64_t stride_b{}; + at::opmath_type beta; + C_Dtype* c{}; + int64_t ldc{}; + int64_t stride_c{}; + int64_t batch{}; +private: + bool duplicate_inputs_{false}; +}; + +template +struct ScaledGemmParams : OpParams { + ScaledGemmParams() = default; + + std::string BLASSignature() const override { + // Excluding use_fast_accum and use_rowise booleans for now + if (bias_ptr == nullptr) { + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, transa, transb, + ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), + ComputeTypeFor(), ComputeTypeFor()); + } + else { + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, bias_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, transa, transb, + ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(bias_dtype), + ComputeTypeFor(), ComputeTypeFor()); + } + } + + std::string Signature() const override { + // In Blas.cpp, code defaults to a bias_dtype of Half even when there is no bias vector. + // Search for this line:: + // params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; + // + // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector. + return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s", + transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, + bias_ptr == nullptr ? "None" : at::toString(bias_dtype)); + } + + size_t GetSizeA() const { + size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t size_dense = m * k; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k); + size_t size_dense = k * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = ldc * n; + size_t size_dense = m * n; + return sizeof(T) * (size_stride > size_dense ? 
size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + ScaledGemmParams* DeepCopy(bool duplicate_inputs) const { + ScaledGemmParams* copy = new ScaledGemmParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size); + copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(ScaledGemmParams *other) { + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + const void* a{}; + const void* a_scale_ptr{}; + int64_t lda{}; + ScalarType a_dtype{}; + ScalarType a_scale_dtype{}; + const void* b{}; + const void* b_scale_ptr{}; + int64_t ldb{}; + ScalarType b_dtype{}; + ScalarType b_scale_dtype{}; + const void* bias_ptr{}; + ScalarType bias_dtype{}; + void* c{}; + const void* c_scale_ptr{}; + int64_t ldc{}; + ScalarType c_dtype{}; + void* amax_ptr{}; + bool use_fast_accum{}; + bool use_rowwise{}; +private: + bool duplicate_inputs_{false}; +}; + +} // namespace at::cuda::tunable diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h new file mode 100644 index 0000000000000000000000000000000000000000..ece61c08cfefa518c0245508c46c50aa64027bbb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h @@ -0,0 +1,685 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define TORCH_HIPBLASLT_CHECK(EXPR) \ + do { \ + hipblasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \ + "hipblaslt error: ", \ + hipblasStatusToString(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +template +constexpr hipDataType HipDataTypeFor(); + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_32F; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_16F; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_16BF; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_64F; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_8F_E4M3_FNUZ; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_8F_E5M2_FNUZ; +} + +// This code is instantiated regardless of ROCm version. +// Prior to ROCm 6.3, we hard-code the known enum values. 
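(The ROCm-version fallback the comment above describes is shown in the specializations immediately below.) Separately, the TORCH_HIPBLASLT_CHECK macro defined earlier in this header and used throughout the rest of it follows the usual do { ... } while (0) status-check idiom; a generic standalone sketch of that pattern, with hypothetical names and no dependence on hipBLASLt, is:

#include <stdexcept>
#include <string>

enum class Status { Success, Error };
inline const char* status_to_string(Status s) {
  return s == Status::Success ? "success" : "error";
}

// Wrapping the body in do { ... } while (0) makes the macro expand to a single
// statement, so it composes safely inside if/else branches without braces.
#define CHECK_STATUS(EXPR)                                              \
  do {                                                                  \
    Status err__ = (EXPR);                                              \
    if (err__ != Status::Success) {                                     \
      throw std::runtime_error(std::string("error: ") +                 \
                               status_to_string(err__) +                \
                               " when calling `" #EXPR "`");            \
    }                                                                   \
  } while (0)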
+template <> +constexpr hipDataType HipDataTypeFor() { +#if ROCM_VERSION >= 60300 + return HIP_R_8F_E4M3; +#else + return static_cast(28); +#endif +} + +template <> +constexpr hipDataType HipDataTypeFor() { +#if ROCM_VERSION >= 60300 + return HIP_R_8F_E5M2; +#else + return static_cast(29); +#endif +} + +// This type is not intended for matrix types but rather a scale factor. +// Return a dummy value to satisfy linker. +template <> +constexpr hipDataType HipDataTypeFor() { + return static_cast(500); +} + +template +int GetBatchFromParams(const GemmParams* params) { + return 1; +} + +template +int GetBatchFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetBatchFromParams(const GemmStridedBatchedParams* params) { + return params->batch; +} + +template +int GetBatchFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmStridedBatchedParams* params) { + return params->stride_a; +} + +template +int GetStrideAFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmStridedBatchedParams* params) { + return params->stride_b; +} + +template +int GetStrideBFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmStridedBatchedParams* params) { + return params->stride_c; +} + +template +int GetStrideCFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +float GetAlphaFromParams(const GemmParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const GemmAndBiasParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const GemmStridedBatchedParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const ScaledGemmParams* params) { + return 1.0; +} + +template +float GetBetaFromParams(const GemmParams* params) { + return params->beta; +} + +template +float GetBetaFromParams(const GemmAndBiasParams* params) { + return 0.0; +} + +template +float GetBetaFromParams(const GemmStridedBatchedParams* params) { + return params->beta; +} + +template +float GetBetaFromParams(const ScaledGemmParams* params) { + return 0.0; +} + +template +bool GetUseRowwiseFromParams(const GemmParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const GemmAndBiasParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const GemmStridedBatchedParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const ScaledGemmParams* params) { + return params->use_rowwise; +} + +template +const void* GetAScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const ScaledGemmParams* params) { + return 
params->a_scale_ptr; +} + +template +const void* GetBScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const ScaledGemmParams* params) { + return params->b_scale_ptr; +} + +template +const void* GetDScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const ScaledGemmParams* params) { + return params->c_scale_ptr; +} + +template +const void* GetBiasPointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { + return params->bias; +} + +template +const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetBiasPointerFromParams(const ScaledGemmParams* params) { + return params->bias_ptr; +} + +template +hipDataType GetBiasTypeFromParams(const GemmParams* params) { + return HIP_R_32F; +} + +template +hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { + return HipDataTypeFor(); +} + +template +hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { + return HIP_R_32F; +} + +template +hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { + return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { + return params->activation; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +static hipblasOperation_t _hipblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return HIPBLAS_OP_N; + case 't': + case 'T': + return HIPBLAS_OP_T; + case 'c': + case 'C': + return HIPBLAS_OP_C; + } + TORCH_CHECK(false, + "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +static char _charFromhipblasOp(hipblasOperation_t op) { + switch (op) { + case HIPBLAS_OP_N: + return 'N'; + case HIPBLAS_OP_T: + return 'T'; + case HIPBLAS_OP_C: + return 'C'; + } + TORCH_CHECK(false, + "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`"); +} + +static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) { + if (layout == BlasOp::N) { + return HIPBLAS_OP_N; + } + return HIPBLAS_OP_T; +} + +static size_t GetHipblasltWorkspaceSize() { + static const auto env = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); + // 256MB is max workspace size allowed for hipblaslt + // hipblaslt-bench uses 32MB + // recommendation from hipblaslt author was 76MB + // TunableOp hipBLASLt 
workspace size is aligned with + // PyTorch's default in CUDABlas.cpp (_parseChosenWorkspaceSize) + size_t workspace_size = 76*1024; + if (env) { + try { + workspace_size = std::stoi(env.value()); + } catch(std::invalid_argument const& e) { + TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,", + " using default workspace size of ", workspace_size, " KiB."); + } catch(std::out_of_range const& e) { + TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,", + " using default workspace size of ", workspace_size, " KiB."); + } + } + return workspace_size * 1024; +} + +template +struct HipBlasLtDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDABLAS_CHECK(destructor(x)); + } + } +}; + +template +class HipBlasLtDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; + +class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor< + hipblasLtMatmulDescOpaque_t, + &hipblasLtMatmulDescDestroy> { + public: + HipBlasLtMatmulDescriptor( + hipblasComputeType_t compute_type, + hipDataType scale_type) { + hipblasLtMatmulDesc_t raw_descriptor = nullptr; + TORCH_HIPBLASLT_CHECK( + hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) { + TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +template +class HipblasltGemmOp : public Callable { + public: + HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {} + + TuningStatus Call(const ParamsT* params) override { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); + auto opa = _hipblasOpFromChar(params->transa); + auto opb = _hipblasOpFromChar(params->transb); + + TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen"); + + float alpha = GetAlphaFromParams(params); + float beta = GetBetaFromParams(params); + + hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; + if (opa == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda)); + } + if (opb == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb)); + } + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc)); + + // specific to batched gemmm + int batch = GetBatchFromParams(params); + if (batch > 1) { + int64_t stride_a = GetStrideAFromParams(params); + int64_t stride_b = GetStrideBFromParams(params); + int64_t stride_c = GetStrideCFromParams(params); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, 
HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); + } + + hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; + if (at::globalContext().allowTF32CuBLAS()) { + computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + } + HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); + + // specific to scaled gemm + const void* mat1_scale_ptr = GetAScalePointerFromParams(params); + const void* mat2_scale_ptr = GetBScalePointerFromParams(params); + const void* result_scale_ptr = GetDScalePointerFromParams(params); + if (mat1_scale_ptr && mat2_scale_ptr) { +#ifdef HIPBLASLT_VEC_EXT + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr); + } + else +#endif + { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); + } +#ifdef HIPBLASLT_OUTER_VEC + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + } +#endif + } + if (result_scale_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); + } + + const void* bias_ptr = GetBiasPointerFromParams(params); + auto bias_datatype = GetBiasTypeFromParams(params); + if (bias_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); + auto activation = GetActivationFromParams(params); + if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); + } + else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); + } + else { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); + } + } + + size_t workspace_size = GetHipblasltWorkspaceSize(); + + auto op_handle = at::cuda::getCurrentCUDABlasLtHandle(); + + size_t ret_workspace_size = 0; + auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle, + matmul.descriptor(), + &alpha, + mat_a, + mat_b, + &beta, + mat_c, + mat_c, + algo_, + ret_workspace_size); + + if (status == HIPBLAS_STATUS_SUCCESS) { + if (ret_workspace_size >= workspace_size) { + return FAIL; + } + } + else { + return FAIL; + } + + void* workspace_buffer = nullptr; + if (workspace_size > 0) { + workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size); + } + + TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle, + matmul.descriptor(), + &alpha, + params->a, + mat_a, + params->b, + mat_b, + &beta, + params->c, + mat_c, + params->c, + mat_c, + &algo_, + workspace_buffer, + workspace_size, + 
at::cuda::getCurrentCUDAStream())); + + //TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c)); + if (workspace_size > 0) { + c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer); + } + return OK; + } + + private: + hipblasLtMatmulAlgo_t algo_; +}; + +template +auto GetHipBlasLtTypeStringAndOps() { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); + std::vector heuristic_result; +#if ROCM_VERSION == 60400 + // hipblaslt TT fp32 regression on ROCm 6.4, cannot use + if ((a_datatype == HIP_R_32F || b_datatype == HIP_R_32F || in_out_datatype == HIP_R_32F) + && (transa_outer == HIPBLAS_OP_T && transb_outer == HIPBLAS_OP_T)) { + std::vector>>> ignore; + return ignore; + } +#endif + + hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; + if (at::globalContext().allowTF32CuBLAS()) { + computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + } + + hipblasLtHandle_t handle; + TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle)); + TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle, + hipblaslt_ext::GemmType::HIPBLASLT_GEMM, + transa_outer, + transb_outer, + a_datatype, + b_datatype, + in_out_datatype, + in_out_datatype, + computeType, + heuristic_result)); + TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle)); + + int returned_algo_count = heuristic_result.size(); + std::vector>>> ret; + for (int i = 0; i < returned_algo_count; i++) { + auto algo = heuristic_result[i].algo; + int algo_index = hipblaslt_ext::getIndexFromAlgo(algo); + auto callable = std::make_unique>(algo); + std::string type_string = fmt::sprintf("Gemm_Hipblaslt_%d", algo_index); + ret.emplace_back(type_string, std::move(callable)); + } + + return ret; +} + +template +auto GetHipBlasLtGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtScaledGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +#undef TORCH_HIPBLASLT_CHECK + +} // namespace at::cuda::tunable diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h new file mode 100644 index 0000000000000000000000000000000000000000..094f97d61e185eb858e71f72fa344eafcb7cc737 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#define ROCBLAS_BETA_FEATURES_API +#include + +#define TORCH_ROCBLAS_CHECK(EXPR) \ + do { \ + rocblas_status __err = EXPR; \ + TORCH_CHECK(__err == rocblas_status_success, \ + "rocblas error: ", \ + rocblas_status_to_string(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +template +constexpr rocblas_datatype RocBlasDataTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_bf16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f64_c; +} + +template +constexpr rocblas_datatype RocBlasComputeTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // FP16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // BF16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f64_c; +} + +template +auto DoCastForHalfOrBfloat16(const T fp) { + return fp; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const Half fp) { + // alpha and beta should be the same as compute_type, in Half case it is float. + float h = fp; + return h; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) { + // alpha and beta should be the same as compute_type, in bfloat16 case it is float. 
+ float h = fp; + return h; +} + +static rocblas_operation _rocblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return rocblas_operation_none; + case 't': + case 'T': + return rocblas_operation_transpose; + case 'c': + case 'C': + return rocblas_operation_conjugate_transpose; + } + TORCH_CHECK(false, + "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +template +class RocblasGemmOp : public Callable> { + public: + RocblasGemmOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) + return FAIL; // no support for TF32 in rocBLAS + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, + params->b, input_output_type, params->ldb, + &h_b, + params->c, input_output_type, params->ldc, + params->c, input_output_type, params->ldc, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(fmt::sprintf("Gemm_Rocblas_%d", solutions[i]), std::move(callable))); + } + return ret; +} + +template +class RocblasGemmStridedBatchedOp : public Callable> { + public: + RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmStridedBatchedParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) + return FAIL; // no support for TF32 in rocBLAS + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_strided_batched_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, params->stride_a, + params->b, input_output_type, params->ldb, params->stride_b, + &h_b, + params->c, input_output_type, params->ldc, params->stride_c, + params->c, input_output_type, params->ldc, 
params->stride_c, + params->batch, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmStridedBatchedTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + // Sort the solutions in ascending order to make the solution vector deterministic across runs + std::sort(solutions.begin(), solutions.end()); + + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable))); + } + return ret; +} + +} // namespace at::cuda::tunable diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h new file mode 100644 index 0000000000000000000000000000000000000000..7431c68879522f42fbe4ace3c5c3cd7aebd3bd5a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h @@ -0,0 +1,50 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include + +#include + +namespace at::cuda::tunable { + +class StreamTimer : public ITimer { + public: + StreamTimer(); + ~StreamTimer() override; + + void Start() override; + + void End() override; + + float Duration() override; + + private: + cudaEvent_t start_{}; + cudaEvent_t end_{}; +}; + +class StreamTimerNoSync : public ITimer { + public: + StreamTimerNoSync(); + ~StreamTimerNoSync() override; + + void Start() override; + + void End() override; + + float Duration() override; + + private: + cudaEvent_t start_{}; + cudaEvent_t end_{}; +}; + +} // namespace at::cuda::tunable diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/Tunable.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/Tunable.h new file mode 100644 index 0000000000000000000000000000000000000000..b43a73420589ce547dd2bd0f6d427bc97b471744 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/Tunable.h @@ -0,0 +1,241 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TUNABLE_LOGV(LEVEL, ...) getTuningContext()->Log(LEVEL, __VA_ARGS__) +#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__) +#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__) +#define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__) + +namespace at::cuda::tunable { + +enum TORCH_CUDA_CPP_API TuningStatus { + OK = 0, + FAIL = 1, + UNSUPPORTED = 2, +}; + +// Mapping from params signature to kernel id +class TORCH_CUDA_CPP_API ResultEntry { + public: + explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {} + explicit ResultEntry(std::string key, double time, std::string blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(std::move(blas_sig)) {} + bool operator==(const ResultEntry& other) const { return key_ == other.key_; } + bool operator!=(const ResultEntry& other) const { return key_ != other.key_; } + operator std::string () { return key_; } + std::string GetKey() const { return key_; } + double GetTime() const { return time_; } + friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry); + static ResultEntry Null() { return ResultEntry("Null", 0.0); } + static ResultEntry Default() { return ResultEntry("Default", 0.0); } + + private: + std::string key_; + double time_; + std::string blas_sig_; +}; + +typedef std::unordered_map KernelMap; +typedef std::unordered_map ResultsMap; +typedef std::unordered_map> UntunedMap; + +struct TORCH_CUDA_CPP_API TuningResults { + // Validates if these results are compatible with the libraries + std::unordered_map validators; + + // Mapping from Callable signature to Callable's tuning result + ResultsMap results; +}; + +class TORCH_CUDA_CPP_API TuningResultsManager { + public: + TuningResultsManager() = default; + ~TuningResultsManager() = default; + + KernelMap Lookup(const std::string& op_signature); + + ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature); + + void AddImpl(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best, + KernelMap& kernel_map); + + void Add(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best); + + void Delete(const std::string& op_signature, const std::string& params_signature); + + void DisjointMergeImpl( + const std::string& op_signature, + const KernelMap& kernel_map, + /*out*/ ResultsMap& results); + + void Load(const ResultsMap& results_to_load); + + ResultsMap Dump(); + + void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map); + + size_t GetSize(); + + void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature, + const std::string& params_signature, const std::string& blas_signature); + private: + std::mutex lock_; + ResultsMap results_; + UntunedMap untuned_results_; + +}; + +class TORCH_CUDA_CPP_API TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + ~TuningResultsValidator() = default; + + std::unordered_map GetAllValidators() const; + TuningStatus ValidateAll(const std::unordered_map& to_validate) const; + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + protected: + static 
std::string GetPyTorchVersion() ; + TuningStatus ValidatePyTorchVersion(const std::string& value) const; + + public: + static constexpr const std::array mandatory_keys{"PT_VERSION"}; + + private: + GetValidateFuncs validators_; +}; + +class TORCH_CUDA_CPP_API TuningContext { + public: + TuningContext(); + ~TuningContext(); + TuningContext(TuningContext &) = delete; + TuningContext(TuningContext &&) = delete; + TuningContext &operator=(TuningContext &) = delete; + TuningContext &operator=(TuningContext &&) = delete; + + void EnableTunableOp(bool value); + bool IsTunableOpEnabled() const; + + void EnableTuning(bool value); + bool IsTuningEnabled() const; + + void EnableRecordUntuned(bool value); + bool IsRecordUntunedEnabled() const; + std::ofstream& GetUntunedFile(); + + void EnableNumericsCheck(bool value); + bool IsNumericsCheckEnabled() const; + + void SetMaxTuningDurationMs(int max_duration_ms); + int GetMaxTuningDurationMs() const; + + void SetMaxTuningIterations(int max_iter); + int GetMaxTuningIterations() const; + + void SetMaxWarmupDurationMs(int max_duration_ms); + int GetMaxWarmupDurationMs() const; + + void SetMaxWarmupIterations(int max_iter); + int GetMaxWarmupIterations() const; + + void EnableICacheFlush(bool value); + bool IsICacheFlushEnabled() const; + + void SetRotatingBufferSize(int size); + int GetRotatingBufferSize() const; + + TuningResultsManager& GetTuningResultsManager(); + + TuningResultsValidator& GetTuningResultsValidator(); + + TuningResults GetTuningResults(); + + TuningStatus LoadTuningResults(const TuningResults& tr); + + void SetFilename(const std::string& filename, bool insert_device_ordinal=false); + std::string GetFilename() const; + + void WriteFileOnExit(bool value); + + bool ReadFile(const std::string& filename={}); + bool WriteFile(const std::string& filename={}); + + template + void Log(int level, Types... args) { + if (GetLogOkay() && GetLogLevel() >= level) { + GetLog() << c10::str(args...) << std::endl; + } + } + + private: + std::string GetLogFilename() const; + int GetLogLevel() const; + bool GetLogOkay() const; + std::ostream& GetLog() const; + + bool enable_; + bool tuning_enable_; + bool record_untuned_enable_; + bool manager_initialized_; + bool write_file_on_exit_; + bool numerics_check_enable_; + int max_tuning_duration_ms_; + int max_tuning_iterations_; + int max_warmup_duration_ms_; + int max_warmup_iterations_; + bool icache_flush_; + int rotating_buffer_size_; + mutable TuningResultsManager manager_; + mutable c10::once_flag manager_init_once_; + TuningResultsValidator validator_; + std::string filename_; + std::ofstream untuned_file_; + size_t results_count_from_input_file_; + bool is_shutting_down_; +}; + +TORCH_CUDA_CPP_API TuningContext* getTuningContext(); + +class ITimer { + public: + ITimer() = default; + virtual ~ITimer() = default; + + virtual void Start() = 0; + virtual void End() = 0; + + /// Computes the elapsed time in milliseconds between Start() and End() + virtual float Duration() = 0; +}; + +} // namespace at::cuda::tunable diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h new file mode 100644 index 0000000000000000000000000000000000000000..f2c100fb51984580e91a9b7ecfb9a1a4d34bdce0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h @@ -0,0 +1,327 @@ +// Original TunableOp is from onnxruntime. 
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#ifdef USE_ROCM +#include +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::cuda::tunable { + +template +class DefaultGemmOp : public Callable> { + public: + TuningStatus Call(const GemmParams* params) override { + at::cuda::blas::gemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->beta, + params->c, params->ldc); + return OK; + } +}; + +static bool _transposeBoolFromChar(char op) { + return op == 't' || op == 'T'; +} + +template +class DefaultGemmAndBiasOp : public Callable> { + public: + TuningStatus Call(const GemmAndBiasParams* params) override { + at::cuda::blas::gemm_and_bias( + _transposeBoolFromChar(params->transa), + _transposeBoolFromChar(params->transb), + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->bias, + params->c, params->ldc, + params->activation); + return OK; + } +}; + +template +class DefaultGemmStridedBatchedOp : public Callable> { + public: + TuningStatus Call(const GemmStridedBatchedParams* params) override { + at::cuda::blas::bgemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, params->stride_a, + params->b, params->ldb, params->stride_b, + params->beta, + params->c, params->ldc, params->stride_c, + params->batch); + return OK; + } +}; + +template +class DefaultScaledGemmOp : public Callable> { + public: + TuningStatus Call(const ScaledGemmParams* params) override { + at::cuda::blas::scaled_gemm( + params->transa, + params->transb, + params->m, + params->n, + params->k, + params->a, + params->a_scale_ptr, + params->lda, + params->a_dtype, + params->a_scale_dtype, + params->b, + params->b_scale_ptr, + params->ldb, + params->b_dtype, + params->b_scale_dtype, + params->bias_ptr, + params->bias_dtype, + params->c, + params->c_scale_ptr, + params->ldc, + params->c_dtype, + params->use_fast_accum, + params->use_rowwise); + return OK; + } +}; + +template +inline bool IsZero(T v) { + return v == 0.0f; +} + +template <> +inline bool IsZero(BFloat16 v) { + return v.x == 0; +} + +template <> +inline bool IsZero(Half v) { + return float(v) == 0.0f; +} + +template <> +inline bool IsZero(c10::complex v) { + return v == 0.0; +} + +template <> +inline bool IsZero(c10::complex v) { + return v == 0.0f; +} + +template +inline const char* TypeName(T v) { + return "unknown"; +} + +template <> +inline const char* TypeName(float v) { + if (at::globalContext().allowTF32CuBLAS()) { + return "tf32"; + } else { + return "float"; + } +} + +template <> +inline const char* TypeName(double v) { + return "double"; +} + +template <> +inline const char* TypeName(BFloat16 v) { + return "BFloat16"; +} + +template <> +inline const char* TypeName(Half v) { + return "Half"; +} + +template <> +inline const char* TypeName(Float8_e4m3fn v) { + return "Float8_e4m3fn"; +} + +template <> +inline const char* TypeName(Float8_e5m2 v) { + return "Float8_e5m2"; +} + +template <> +inline const char* 
TypeName(Float8_e4m3fnuz v) { + return "Float8_e4m3fnuz"; +} + +template <> +inline const char* TypeName(Float8_e5m2fnuz v) { + return "Float8_e5m2fnuz"; +} + +template <> +inline const char* TypeName(Float8_e8m0fnu v) { + return "Float8_e8m0fnu"; +} + +template <> +inline const char* TypeName(c10::complex v) { + return "c10::complex"; +} + +template <> +inline const char* TypeName(c10::complex v) { + return "c10::complex"; +} + +template +class GemmTunableOp : public TunableOp> { + public: + GemmTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (!env_rocblas.has_value() || env_rocblas.value()) { + for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class GemmAndBiasTunableOp : public TunableOp> { + public: + GemmAndBiasTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class GemmStridedBatchedTunableOp : public TunableOp> { + public: + GemmStridedBatchedTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (!env_rocblas.has_value() || env_rocblas.value()) { + for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class ScaledGemmTunableOp : 
public TunableOp> { + public: + ScaledGemmTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c", + TypeName(AT{}), + TypeName(BT{}), + TypeName(CT{}), + BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +} // namespace at::cuda::tunable diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h new file mode 100644 index 0000000000000000000000000000000000000000..e805057ee448043bda85209beb2cc21af3da6cff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h @@ -0,0 +1,430 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#endif + +#include +#include +#include +#include + +namespace at::cuda::tunable { + +template +class Callable { + public: + virtual ~Callable() = default; + virtual TuningStatus Call(const ParamsT*) { + return FAIL; + } + virtual TuningStatus IsSupported(const ParamsT* params) { + return Call(params); + } +}; + +namespace { + +/** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */ + +class Stats { + public: + Stats() { + _n = 0UL; + _mean = 0.0; + _M2 = 0.0; + _sum = 0.0; + _min = 0.0; + _max = 0.0; + } + + void sample_value(const double x) { + double delta = 0; + _sum = _sum + x; + if (0UL == _n) { + _min = x; + _max = x; + } + else { + _min = _min < x ? _min : x; + _max = _max > x ? 
_max : x; + } + _n = _n + 1UL; + delta = x - _mean; + _mean = _mean + delta/_n; + _M2 = _M2 + delta * (x - _mean); + } + + double variance() const { + return _M2/(_n-1); + } + + double stddev() const { + return std::sqrt(variance()); + } + + unsigned long _n; + double _mean; + double _M2; + double _sum; + double _min; + double _max; +}; + +class FixedSizeStack { + private: + std::deque stack; + const size_t max_size; + + public: + FixedSizeStack(size_t size) : max_size(size) {} + + void push(const std::string& value) { + if (stack.size() >= max_size) { + stack.pop_front(); // Remove the oldest entry + } + stack.push_back(value); // Add new entry + } + + auto rbegin() { return stack.rbegin(); } + auto rend() { return stack.rend(); } +}; + +} // anonymous namespace + +template +class TunableOp { + public: + virtual ~TunableOp() = default; + + TuningStatus operator()(const ParamsT* params) { + ResultEntry result = ResultEntry::Null(); + TuningContext* ctx = getTuningContext(); + if (ctx->IsTunableOpEnabled()) { + auto& mgr = ctx->GetTuningResultsManager(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + auto blas_sig = params->BLASSignature(); + result = mgr.Lookup(op_sig, params_sig); + // If there is not previous tuning result been found, we do the tuning iff tuning is enabled + if (result == ResultEntry::Null()) { + if (ctx->IsTuningEnabled()) { + result = FindFastest(params); + mgr.Add(op_sig, params_sig, result); + } + else if (ctx->IsRecordUntunedEnabled()) { + // or record the gemm into file + mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig, blas_sig); + } + } + } + else { + result = ResultEntry::Default(); + } + if (result == ResultEntry::Null()) { + TUNABLE_LOG2("no result, using default"); + result = ResultEntry::Default(); + } + auto iter = ops_.find(result); + TORCH_CHECK(iter != ops_.end()); + return iter->second->Call(params); + } + + virtual std::string Signature() { + // According to C++17 standard https://wg21.link/n4659 section 15.7.4 + // > if the operand of typeid refers to the + // > object under construction or destruction, typeid yields the std::type_info object representing the constructor + // > or destructor’s class. + // So delay the op signature generation. 
+ c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); }); + return signature_; + } + + protected: + void RegisterOp(const std::string& name, std::unique_ptr> op) { + this->op_names_.emplace_back(name); + this->ops_.emplace(name, std::move(op)); + } + + private: + static void WarmUp(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + for (size_t i = 0; i < num_iter; i++) { + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + } + + static double ProfileSimple(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + StreamTimerNoSync timer{}; + + // Small Mandatory Warmup + // Reduces outliers + for (size_t i = 0; i < 2; i++) { + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + + timer.Start(); + for (size_t i = 0; i < num_iter; i++) { + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + timer.End(); + return timer.Duration() / num_iter; + } + + static Stats ProfileStats(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + std::vector timer(num_iter); + + // Small Mandatory Warmup + // Reduces outliers + for (size_t i = 0; i < 2; i++) { + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + + for (size_t i = 0; i < num_iter; i++) { + timer[i].Start(); + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + timer[i].End(); + if (do_flush) { + at::cuda::flush_icache(); + } + } + Stats s; + for (size_t i = 0; i < num_iter; i++) { + s.sample_value(timer[i].Duration()); + } + return s; + } + + protected: + virtual ResultEntry FindFastest(const ParamsT* params) { + TuningContext* ctx = getTuningContext(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + auto blas_sig = params->BLASSignature(); + TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); + auto min_duration_ms = std::numeric_limits::infinity(); + std::string id_name = "Default"; + ParamsT* reference_params = nullptr; + auto top_solns = FixedSizeStack(5); + + // numeric check option is controlled by non-static env var, so check it once per tuned operator + bool do_numerics_check = ctx->IsNumericsCheckEnabled(); + + // calcaulte a reference answer for numerical check + if (do_numerics_check) { + reference_params = params->DeepCopy(false); + TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + } + + // need copies of params to reuse + // make as many copies as will fill the requested rotating buffer size, if requested + // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int + size_t rotating_size = ctx->GetRotatingBufferSize(); + bool use_buffer_rotation = (rotating_size > 0); + size_t param_size = params->GetSize(use_buffer_rotation); + size_t param_count = (rotating_size / param_size) + 1; + constexpr size_t MB = 1024ull*1024; + if (use_buffer_rotation) { + TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ", + "Needed Size: ", param_size/MB, " MiB. 
", + "Needed number of param copies: ", param_count); + } + TORCH_CHECK(param_count > 0); + + std::vector reusable_params(param_count); + for (size_t i = 0; i < param_count; i++) { + reusable_params[i] = params->DeepCopy(use_buffer_rotation); + } + + // for rotating buffer + size_t offset = 0; + + for (size_t i = 0; i < op_names_.size(); i++) { + auto* candidate = ops_[op_names_[i]].get(); // borrow pointer + + if (do_numerics_check) { + ParamsT* numerical_params = params->DeepCopy(false); + auto status = candidate->Call(numerical_params); + if (status != OK) { + numerical_params->Delete(); + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + status = reference_params->NumericalCheck(numerical_params); + numerical_params->Delete(); + if (status != OK) { + TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + else { + auto status = candidate->Call(reusable_params[0]); + if (status != OK) { + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + + // collect a small profile + int approx_num_iter = 3; + auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset); + double approx_duration = s._mean; + // bail if too slow + if (approx_duration > 1.5 * min_duration_ms) { + TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + + // 2nd phase skip, more aggressive + approx_num_iter = 10; + s = ProfileStats(candidate, reusable_params, approx_num_iter, offset); + approx_duration = s._mean; + // bail if too slow + if (approx_duration > 1.15 * min_duration_ms) { + TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + + // for warmup does user set max duration, max iters, or both? + // warmup is skipped by default, i.e. warmup_iter = 0 + // warmup will be set to the non-zero value of max_warmup_duration + // or max_warmup_iter + // if both are non-zero, we take the smaller of the two. + double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); + int max_warmup_iter = ctx->GetMaxWarmupIterations(); + int warmup_iter = 0; // default + if (max_warmup_duration > 0) { + int duration_iters = max_warmup_duration / approx_duration; + if (max_warmup_iter > 0) { + warmup_iter = std::min(max_warmup_iter, duration_iters); + } + else { + warmup_iter = duration_iters; + } + } + else if (max_warmup_iter > 0) { + warmup_iter = max_warmup_iter; + } + + // for tuning does user set max duration, max iters, or both? 
+ double max_tuning_duration = ctx->GetMaxTuningDurationMs(); + int max_tuning_iter = ctx->GetMaxTuningIterations(); + int tuning_iter = 100; // default + if (max_tuning_duration > 0) { + int duration_iters = max_tuning_duration / approx_duration; + if (max_tuning_iter > 0) { + tuning_iter = std::min(max_tuning_iter, duration_iters); + } + else { + tuning_iter = duration_iters; + } + } + else if (max_tuning_iter > 0) { + tuning_iter = max_tuning_iter; + } + // tuning must run at least 1 iteration + tuning_iter = std::max(1, tuning_iter); + + // do the full warmup followed by tuning + double warmup_ms = warmup_iter * approx_duration; + double tuning_ms = tuning_iter * approx_duration; + TUNABLE_LOG3("├──tuning using " + "warmup iters ", warmup_iter, " [", warmup_ms, " ms] " + "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ", + "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); + TUNABLE_LOG3("├──offset at ", offset); + WarmUp(candidate, reusable_params, warmup_iter, offset); + s = ProfileStats(candidate, reusable_params, tuning_iter, offset); + auto s_stddev = s.stddev(); + // Assume normal distribution. + // Solution with smallest mean + 2*sigma will be a better solution? + // if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) { + if (s._mean < min_duration_ms) { + TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i], + " min ", s._min, + " max ", s._max, + " mean ", s._mean, + " std ", s_stddev); + min_duration_ms = s._mean; + id_name = op_names_[i]; + std::string current_soln = std::to_string(s._mean) + " " + op_names_[i]; + top_solns.push(current_soln); + } + else { + TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i], + " min ", s._min, + " max ", s._max, + " mean ", s._mean, + " std ", s_stddev); + } + } + + for (size_t i = 0; i < reusable_params.size(); i++) { + reusable_params[i]->Delete(); + } + if (reference_params) { + reference_params->Delete(); + } + + TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); + TUNABLE_LOG2("└──top five solutions for ", op_sig, '(', params_sig, ") "); + for (auto it = top_solns.rbegin(); it != top_solns.rend(); ++it) { + TUNABLE_LOG2(" ", *it); + } + return ResultEntry(id_name, min_duration_ms, blas_sig); + } + + private: + std::string CreateSignature() { +#ifndef _WIN32 + const auto* name = typeid(*this).name(); + // NOLINTNEXTLINE(*array*) + char buf[256]; + size_t buf_len = 256; + abi::__cxa_demangle(name, buf, &buf_len, nullptr); + buf[255] = '\0'; + return buf; +#else + return typeid(*this).name(); +#endif + } + + mutable c10::once_flag signature_init_once_; + std::string signature_; + + std::unordered_map>> ops_; + std::vector op_names_; +}; + +struct OpParams { + virtual ~OpParams() = default; + virtual std::string Signature() const = 0; + virtual std::string BLASSignature() const = 0; +}; + +} // namespace at::cuda::tunable diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Descriptors.h b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Descriptors.h new file mode 100644 index 0000000000000000000000000000000000000000..aadee5d7a5041b30556e9729b9ab6f6375f66322 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Descriptors.h @@ -0,0 +1,409 @@ +#pragma once + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8907 
+#define USE_CUDNN_RNN_V8_API +#endif + +namespace at::native { + +std::string cudnnTypeToString(cudnnDataType_t dtype); + +// TODO: Add constructors for all of the descriptors + +inline int dataSize(cudnnDataType_t dataType) +{ + switch (dataType) { + case CUDNN_DATA_BFLOAT16: + case CUDNN_DATA_HALF: return 2; + case CUDNN_DATA_FLOAT: return 4; + default: return 8; + } +} + +// The stride for a size-1 dimensions is not uniquely determined; in +// fact, it can be anything you want, because the fact that the +// tensor is size 1 at this dimension means that you will never actually +// try advancing your pointer by this stride. +// +// However, CuDNN has a much more stringent requirement on strides: +// if you are passing a contiguous input, it better be the case +// that the stride for dim i is the product of the sizes of dims +// i+1 to the end. This stride is indeed uniquely determined. This +// function modifies 'stride' in place so this invariant holds. +template +static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) { + int64_t z = 1; + int index = 0; + std::vector permutation(dim); + + if (nhwc) { + permutation[index++] = 1; + } + for (int d = dim-1; d > 1; d--) { + permutation[index++] = d; + } + if (!nhwc) { + permutation[index++] = 1; + } + permutation[index++] = 0; + for (int d : permutation) { + if (size[d] == 1) { + stride[d] = z; + } else { + z *= size[d]; + } + } +} + +template +struct DescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + AT_CUDNN_CHECK(dtor(x)); + } + } +}; + +// A generic class for wrapping cuDNN descriptor types. All you need +// is to give the underlying type the Descriptor_t points to (usually, +// if it's cudnnTensorDescriptor_t it points to cudnnTensorStruct), +// the constructor and the destructor. Subclasses are responsible +// for defining a set() function to actually set the descriptor. +// +// Descriptors default construct to a nullptr, and have a descriptor +// initialized the first time you call set() or any other initializing +// function. +template +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_CUDA_CPP_API Descriptor { + public: + // TODO: Figure out why const-correctness doesn't work here + + // Use desc() to access the underlying descriptor pointer in + // a read-only fashion. Most client code should use this. + // If the descriptor was never initialized, this will return + // nullptr. + T* desc() const { return desc_.get(); } + T* desc() { return desc_.get(); } + + // Use mut_desc() to access the underlying descriptor pointer + // if you intend to modify what it points to (e.g., using + // cudnnSetFooDescriptor). This will ensure that the descriptor + // is initialized. Code in this file will use this function. 
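+  // (Editorial note: for a concrete example, TensorDescriptor::set() below
+  //  passes mut_desc() straight into cudnnSetTensorNdDescriptor, so the
+  //  underlying cudnnTensorDescriptor_t is created lazily on the first call.)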
+ T* mut_desc() { init(); return desc_.get(); } +protected: + void init() { + if (desc_ == nullptr) { + T* raw_desc = nullptr; + AT_CUDNN_CHECK(ctor(&raw_desc)); + desc_.reset(raw_desc); + } + } +private: + std::unique_ptr> desc_; +}; + +class TORCH_CUDA_CPP_API RNNDataDescriptor : public Descriptor< + cudnnRNNDataStruct, + &cudnnCreateRNNDataDescriptor, + &cudnnDestroyRNNDataDescriptor> { +public: + void set(const at::Tensor &t, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray); +private: + void set(cudnnDataType_t dataType, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray) { + AT_CUDNN_CHECK(cudnnSetRNNDataDescriptor(mut_desc(), dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, nullptr)); + } +}; + +class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< + cudnnTensorStruct, + &cudnnCreateTensorDescriptor, + &cudnnDestroyTensorDescriptor> { + public: + TensorDescriptor() = default; + explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) { + set(t, pad); + } + + // Note [CuDNN broadcast padding] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // pad specifies the minimum dimensionality of the tensor descriptor + // we produce (it doesn't have anything to do with, e.g., convolution + // padding). If 't' is lower-dimensional than 'pad', the remaining + // dimensions (on the right) are padded with ones. This doesn't + // affect the underlying data layout. This is particularly useful for + // dealing with a peculiarity of the CuDNN API, which is that broadcasting in CuDNN is + // done in two steps: first, the client code is expected to pad out + // (the dimensions) input tensors to be the same dimension as the + // target broadcast, and then second, CuDNN takes of actually + // broadcasting size 1 dimensions. 
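+  //
+  // A small worked example (editorial sketch, not part of the original note):
+  // with pad = 4, a 1-d tensor of shape [5] is reported to cuDNN as a 4-d
+  // tensor of sizes [5, 1, 1, 1]:
+  //
+  //   at::Tensor t = at::ones({5});
+  //   TensorDescriptor desc(t, /*pad=*/4);  // cuDNN sees sizes [5, 1, 1, 1]
+  //
+  // The underlying storage and strides of 't' are untouched; only the
+  // descriptor gains the trailing size-1 dimensions.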
+ + void set(const at::Tensor &t, size_t pad = 0); + void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0); + void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0); + + void print(); + +private: + void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc); + + void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) { + std::vector strides_copy(stride, stride + dim); + fixSizeOneDimStride(dim, size, strides_copy.data(), nhwc); + AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, strides_copy.data())); + } +}; + +std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); + +class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< + cudnnFilterStruct, + &cudnnCreateFilterDescriptor, + &cudnnDestroyFilterDescriptor> { + public: + void set(const at::Tensor &t, int64_t pad = 0) { + set(t, at::MemoryFormat::Contiguous, pad); + } + + void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0); + + void print(); +private: + void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) { + AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size)); + } +}; + +std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d); + +struct TORCH_CUDA_CPP_API ConvolutionDescriptor + : public Descriptor< + cudnnConvolutionStruct, + &cudnnCreateConvolutionDescriptor, + &cudnnDestroyConvolutionDescriptor> { + void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool allow_tf32) { + cudnnDataType_t mathType = dataType; + if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT; + AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, + CUDNN_CROSS_CORRELATION, mathType)); + AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups)); + // See Note [behavior of cudnnFind and cudnnGet] + AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH)); + if(dataType == CUDNN_DATA_HALF) { + AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH)); + } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) { + AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH)); + } + } +}; + +struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor + : public Descriptor< + cudnnSpatialTransformerStruct, + &cudnnCreateSpatialTransformerDescriptor, + &cudnnDestroySpatialTransformerDescriptor> { + void set(cudnnDataType_t dataType, int dim, int* size) { + AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size)); + } +}; + +// NOLINTNEXTLINE(bugprone-exception-escape) +struct TORCH_CUDA_CPP_API DropoutDescriptor + : public Descriptor< + cudnnDropoutStruct, + &cudnnCreateDropoutDescriptor, + &cudnnDestroyDropoutDescriptor> { + at::Tensor state; + + // Initialize a dropout descriptor's RNG state. + // WARNING: This function is very expensive, avoid calling this function! 
+ void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) { + TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout"); + size_t state_size = 0; + AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size)); + AT_ASSERT(options.device().type() == kCUDA); + AT_ASSERT(options.dtype() == kByte); + state = at::empty({static_cast(state_size)}, options); + AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed)); + } + + // Restore a dropout descriptor given a dropout probability and existing RNG state. + void set(cudnnHandle_t handle, float dropout, const at::Tensor& state) { + TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout"); + void *state_ptr = state.data_ptr(); + size_t state_size = state.size(0); + // NB: The seed doesn't actually matter, so we give a dummy value + AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */)); + } + + // Restore a dropout descriptor corresponding to no dropout + void set_no_dropout(cudnnHandle_t handle) { + // NB: seed doesn't matter when dropout = 0, because no random number + // initialization actually takes place when there is no dropout. + // NB: Empirically, cudnnSetDropoutDescriptor is cheap when + // dropout == 0 + AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */)); + } +}; + +struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor< + cudnnRNNStruct, + &cudnnCreateRNNDescriptor, + &cudnnDestroyRNNDescriptor> { + DropoutDescriptor dropout_desc_; + void set(cudnnHandle_t handle, +#ifdef USE_CUDNN_RNN_V8_API + int input_size, + bool packed, +#endif + int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc, + cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional, + cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) { + dropout_desc_ = std::move(dropout_desc); +#ifndef USE_CUDNN_RNN_V8_API + AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6( + handle, + mut_desc(), + hidden_size, + num_layers, + dropout_desc_.desc(), + input_mode, + bidirectional, + mode, + algo, + datatype)); + if (proj_size != 0) { + AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers( + handle, + /*rnnDesc=*/mut_desc(), + /*recProjSize=*/proj_size, + /*outProjSize=*/0)); + } + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 7) { + if (input_type == CUDNN_DATA_HALF) { + cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH); + } + else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) { + cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH); + } + else { + // Technically, as the default it's not necessary to explicitly + // set this. + cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH); + } + } +#else + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + auto math_type = CUDNN_DEFAULT_MATH; + if (prop->major >= 7) { + if (input_type == CUDNN_DATA_HALF) { + math_type = CUDNN_TENSOR_OP_MATH; + } else if (!allow_tf32) { + math_type = CUDNN_FMA_MATH; + } + } + AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v8( + mut_desc(), + algo, + mode, + CUDNN_RNN_DOUBLE_BIAS, + bidirectional, + input_mode, + input_type, + datatype, + math_type, + input_size, + hidden_size, + proj_size ? proj_size : hidden_size, + num_layers, + dropout_desc_.desc(), + packed ? 
CUDNN_RNN_PADDED_IO_DISABLED : CUDNN_RNN_PADDED_IO_ENABLED)); +#endif + } +}; + +struct TORCH_CUDA_CPP_API CTCLossDescriptor + : public Descriptor< + cudnnCTCLossStruct, + &cudnnCreateCTCLossDescriptor, + &cudnnDestroyCTCLossDescriptor> { + void set(cudnnDataType_t datatype) { + AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype)); + } + void setEx( + cudnnDataType_t datatype, + cudnnLossNormalizationMode_t normMode, + cudnnNanPropagation_t gradMode) { + AT_CUDNN_CHECK( + cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode)); + } + void set_v8_v9( + cudnnDataType_t datatype, + cudnnLossNormalizationMode_t normMode, + cudnnNanPropagation_t gradMode, + int maxLabelLength) { +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000 + auto gradModev9 = CUDNN_CTC_ZERO_OOB_GRADIENTS; + if (gradMode == cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN) { + gradModev9 = CUDNN_CTC_SKIP_OOB_GRADIENTS; + } + AT_CUDNN_CHECK( + cudnnSetCTCLossDescriptor_v9(mut_desc(), datatype, normMode, gradModev9, maxLabelLength)); +#else + AT_CUDNN_CHECK( + cudnnSetCTCLossDescriptor_v8(mut_desc(), datatype, normMode, gradMode, maxLabelLength)); +#endif + } + +}; + +struct TORCH_CUDA_CPP_API ActivationDescriptor + : public Descriptor< + cudnnActivationStruct, + &cudnnCreateActivationDescriptor, + &cudnnDestroyActivationDescriptor> { + void set(cudnnActivationMode_t mode) { + AT_ASSERT( + mode == CUDNN_ACTIVATION_RELU, + "TODO: support more cuDNN activation modes"); + AT_CUDNN_CHECK(cudnnSetActivationDescriptor( + mut_desc(), + mode, + cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, + std::numeric_limits::max())); + } +}; + +union Constant +{ + float f; + double d; + Constant(cudnnDataType_t dataType, double value) { + if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) { + f = static_cast(value); + } else { + d = value; + } + } +}; + +} // namespace diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Handle.h b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Handle.h new file mode 100644 index 0000000000000000000000000000000000000000..a665e4bc5178a8aabc654630ace47112a5320961 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Handle.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API cudnnHandle_t getCudnnHandle(); +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Handles.h b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Handles.h new file mode 100644 index 0000000000000000000000000000000000000000..65b5d4454879ad165c8e002fc5df4c400da9303a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Handles.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Types.h b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Types.h new file mode 100644 index 0000000000000000000000000000000000000000..9c72a414ecf38465ab0ae97d7bc33dde57eb23ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Types.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API cudnnDataType_t +getCudnnDataTypeFromScalarType(const at::ScalarType dtype); +cudnnDataType_t getCudnnDataType(const at::Tensor& tensor); + +int64_t cudnn_version(); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Utils.h b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Utils.h new file mode 100644 index 
0000000000000000000000000000000000000000..49631ab2886c9e45fd1d2848bdff025f4f4ecfe7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/Utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +// cuDNN has a buggy check for tensor being contiguous (that is, it does +// not ignore stride for dimension that is equal to 0). This function +// makes tensors which have zero stride contiguous, by setting the +// strides to 1 as cuDNN likes. +inline Tensor contiguousIfZeroInStrides(const Tensor& t) { + for (auto s : t.strides()) { + if (s == 0) + return t.contiguous(); + } + return t; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..3bfceed839d414e01190bb9276170393b9f00e9c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#define STRINGIFY(x) #x +#define STRING(x) STRINGIFY(x) + +#if CUDNN_MAJOR < 8 || (CUDNN_MAJOR == 8 && CUDNN_MINOR < 5) +#pragma message("CuDNN v" STRING( \ + CUDNN_MAJOR) " found, but need at least CuDNN v8. You can get the latest version of CuDNN from https://developer.nvidia.com/cudnn or disable CuDNN with USE_CUDNN=0") +#pragma message "We strongly encourage you to move to 8.5 and above." +#pragma message "This message is intended to annoy you enough to update." +#endif + +#undef STRINGIFY +#undef STRING diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..df7401a2b14f960db727b6ecb6403d39361a2dba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h @@ -0,0 +1,96 @@ +#pragma once + +#include + +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") + +namespace at { + +// AcceleratorHooksInterface is a shared interface provided by all +// accelerators to allow generic code. +// This inferface is hook-based as it corresponds to all the functions +// that are going to be called in a generic way from the CPU code. + +struct TORCH_API AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + virtual ~AcceleratorHooksInterface() = default; + + // Whether this backend was enabled at compilation time. + // This function should NEVER throw. + virtual bool isBuilt() const { + return false; + } + + // Whether this backend can be used at runtime, meaning it was built, + // its runtime dependencies are available (driver) and at least one + // supported device can be used. + // This function should NEVER throw. This function should NOT initialize the context + // on any device (result of hasPrimaryContext below should not change). + // While it is acceptable for this function to poison fork, it is + // recommended to avoid doing so whenever possible. + virtual bool isAvailable() const { + return false; + } + + // Whether the device at device_index is fully initialized or not. 
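+ // (Editorial note: for instance, the CUDA hooks are expected to answer this
+ //  by checking whether a primary CUDA context already exists on that device,
+ //  without creating one as a side effect.)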
+ virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0; + + virtual void init() const { + TORCH_CHECK(false, "Backend doesn`t support init()"); + } + + virtual DeviceIndex deviceCount() const { + return 0; + } + + virtual void setCurrentDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support setCurrentDevice()"); + } + + virtual DeviceIndex getCurrentDevice() const { + TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()"); + return -1; + } + + virtual DeviceIndex exchangeDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support exchangeDevice()"); + return -1; + } + + virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()"); + return -1; + } + + virtual bool isPinnedPtr(const void* data) const { + return false; + } + + virtual Allocator* getPinnedMemoryAllocator() const { + TORCH_CHECK(false, "Backend doesn't support getPinnedMemoryAllocator()"); + return nullptr; + } + + virtual Device getDeviceFromPtr(void* data) const { + TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()"); + } + + virtual const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK(false, "Backend doesn`t support getDefaultGenerator()"); + } + + virtual Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK(false, "Backend doesn`t support getNewGenerator()"); + } +}; + +} // namespace at + +C10_DIAGNOSTIC_POP() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..0222a15d9f6c87e236cfc15b476fbde7cd7b6e81 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h @@ -0,0 +1,224 @@ +#pragma once + +#include +#include +#include + +#include + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +// Forward-declares at::cuda::NVRTC +namespace cuda { +struct NVRTC; +} // namespace cuda + +#ifdef _MSC_VER +constexpr const char* CUDA_HELP = + "PyTorch splits its backend into two shared libraries: a CPU library " + "and a CUDA library; this error has occurred because you are trying " + "to use some CUDA functionality, but the CUDA library has not been " + "loaded by the dynamic linker for some reason. The CUDA library MUST " + "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! " + "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ " + "in your link arguments; many dynamic linkers will delete dynamic library " + "dependencies if you don't depend on any of their symbols. You can check " + "if this has occurred by using link on your binary to see if there is a " + "dependency on *_cuda.dll library."; +#else +constexpr const char* CUDA_HELP = + "PyTorch splits its backend into two shared libraries: a CPU library " + "and a CUDA library; this error has occurred because you are trying " + "to use some CUDA functionality, but the CUDA library has not been " + "loaded by the dynamic linker for some reason. The CUDA library MUST " + "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! " + "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many " + "dynamic linkers will delete dynamic library dependencies if you don't " + "depend on any of their symbols. 
You can check if this has occurred by " + "using ldd on your binary to see if there is a dependency on *_cuda.so " + "library."; +#endif + +// The CUDAHooksInterface is an omnibus interface for any CUDA functionality +// which we may want to call into from CPU code (and thus must be dynamically +// dispatched, to allow for separate compilation of CUDA code). How do I +// decide if a function should live in this class? There are two tests: +// +// 1. Does the *implementation* of this function require linking against +// CUDA libraries? +// +// 2. Is this function *called* from non-CUDA ATen code? +// +// (2) should filter out many ostensible use-cases, since many times a CUDA +// function provided by ATen is only really ever used by actual CUDA code. +// +// TODO: Consider putting the stub definitions in another class, so that one +// never forgets to implement each virtual function in the real implementation +// in CUDAHooks. This probably doesn't buy us much though. +struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + ~CUDAHooksInterface() override = default; + + // Initialize THCState and, transitively, the CUDA state + void init() const override { + TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP); + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK( + false, + "Cannot get default CUDA generator without ATen_cuda library. ", + CUDA_HELP); + } + + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK( + false, + "Cannot get CUDA generator without ATen_cuda library. ", + CUDA_HELP); + } + + Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP); + } + + bool isPinnedPtr(const void* /*data*/) const override { + return false; + } + + virtual bool hasCUDA() const { + return false; + } + + virtual bool hasCUDART() const { + return false; + } + + virtual bool hasMAGMA() const { + return false; + } + + virtual bool hasCuDNN() const { + return false; + } + + virtual bool hasCuSOLVER() const { + return false; + } + + virtual bool hasCuBLASLt() const { + return false; + } + + virtual bool hasROCM() const { + return false; + } + + virtual const at::cuda::NVRTC& nvrtc() const { + TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP); + } + + virtual DeviceIndex current_device() const { + return -1; + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP); + } + + virtual Allocator* getCUDADeviceAllocator() const { + TORCH_CHECK(false, "CUDADeviceAllocator requires CUDA. 
", CUDA_HELP); + } + + virtual bool compiledWithCuDNN() const { + return false; + } + + virtual bool compiledWithMIOpen() const { + return false; + } + + virtual bool supportsDilatedConvolutionWithCuDNN() const { + return false; + } + + virtual bool supportsDepthwiseConvolutionWithCuDNN() const { + return false; + } + + virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const { + return false; + } + + virtual long versionCuDNN() const { + TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP); + } + + virtual long versionMIOpen() const { + TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP); + } + + virtual long versionCUDART() const { + TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP); + } + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP); + } + + virtual double batchnormMinEpsilonCuDNN() const { + TORCH_CHECK(false, + "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP); + } + + virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual int getNumGPUs() const { + return 0; + } + +#ifdef USE_ROCM + virtual bool isGPUArch(const std::vector& /*archs*/, DeviceIndex = -1 /*device_index*/) const { + TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP); + } +#endif + + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." in a variadic macro" +struct TORCH_API CUDAHooksArgs {}; + +TORCH_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs); +#define REGISTER_CUDA_HOOKS(clsname) \ + C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const CUDAHooksInterface& getCUDAHooks(); +} // namespace detail +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/FunctionTraits.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/FunctionTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..7d7658c210ee8c34da119c7c7b04b85c39f288b9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/FunctionTraits.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include + +// Modified from https://stackoverflow.com/questions/7943525/is-it-possible-to-figure-out-the-parameter-type-and-return-type-of-a-lambda + +// Fallback, anything with an operator() +template +struct function_traits : public function_traits { +}; + +// Pointers to class members that are themselves functors. 
+// For example, in the following code: +// template +// struct S { +// func_t f; +// }; +// template +// S make_s(func_t f) { +// return S { .f = f }; +// } +// +// auto s = make_s([] (int, float) -> double { /* ... */ }); +// +// function_traits traits; +template +struct function_traits : public function_traits { +}; + +// Const class member functions +template +struct function_traits : public function_traits { +}; + +// Reference types +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; + +// Free functions +template +struct function_traits { + // arity is the number of arguments. + enum { arity = sizeof...(Args) }; + + using ArgsTuple = std::tuple; + using result_type = ReturnType; + + template + struct arg + { + using type = typename std::tuple_element>::type; + // the i-th argument is equivalent to the i-th tuple element of a tuple + // composed of those arguments. + }; +}; + +template +struct nullary_function_traits { + using traits = function_traits; + using result_type = typename traits::result_type; +}; + +template +struct unary_function_traits { + using traits = function_traits; + using result_type = typename traits::result_type; + using arg1_t = typename traits::template arg<0>::type; +}; + +template +struct binary_function_traits { + using traits = function_traits; + using result_type = typename traits::result_type; + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; +}; + + +// Traits for calling with c10::guts::invoke, where member_functions have a first argument of ClassType +template +struct invoke_traits : public function_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : + public function_traits { +}; + +template +struct invoke_traits : + public function_traits { +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/HIPHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/HIPHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..e13c56da0303208ee336570d0b9bf7c1f5a28d31 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/HIPHooksInterface.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +// The HIPHooksInterface is an omnibus interface for any HIP functionality +// which we may want to call into from CPU code (and thus must be dynamically +// dispatched, to allow for separate compilation of HIP code). See +// CUDAHooksInterface for more detailed motivation. 
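+//
+// A minimal usage sketch (editorial, not part of the upstream header): generic
+// CPU-side code never calls into ATen_hip directly; it goes through the
+// registry accessor declared below, e.g.
+//
+//   if (at::detail::getHIPHooks().hasHIP()) {
+//     auto num_gpus = at::detail::getHIPHooks().getNumGPUs();
+//   }
+//
+// The HIP build supplies a concrete subclass and registers it with
+// REGISTER_HIP_HOOKS(...), which routes these virtual calls to the real
+// implementation at runtime.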
+struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + ~HIPHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); + } + + virtual bool hasHIP() const { + return false; + } + + virtual c10::DeviceIndex current_device() const { + return -1; + } + + bool isPinnedPtr(const void* /*data*/ ) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Pinned memory requires HIP."); + } + + virtual int getNumGPUs() const { + return 0; + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/ ) const override { + TORCH_CHECK(false, "Cannot check primary context without ATen_hip library."); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." in a variadic macro" +struct TORCH_API HIPHooksArgs {}; + +TORCH_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs); +#define REGISTER_HIP_HOOKS(clsname) \ + C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const HIPHooksInterface& getHIPHooks(); + +} // namespace detail +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/HPUHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/HPUHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..8ce324efffe8720ba4f92a450427f9f7f1698eb5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/HPUHooksInterface.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace at { + +struct TORCH_API HPUHooksInterface : AcceleratorHooksInterface { + ~HPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize HPU without HPU backend"); + } + + virtual bool hasHPU() const { + return false; + } + + Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK( + false, "Cannot get device of pointer on HPU without HPU backend"); + } + + bool isPinnedPtr(const void*) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK( + false, + "You should register `HPUHooksInterface` for HPU before call `getPinnedMemoryAllocator`."); + } + + bool hasPrimaryContext( + [[maybe_unused]] DeviceIndex device_index) const override { + TORCH_CHECK( + false, + "You should register `HPUHooksInterface` for HPU before call `hasPrimaryContext`."); + } +}; + +struct TORCH_API HPUHooksArgs {}; + +TORCH_DECLARE_REGISTRY(HPUHooksRegistry, HPUHooksInterface, HPUHooksArgs); +#define REGISTER_HPU_HOOKS(clsname) \ + C10_REGISTER_CLASS(HPUHooksRegistry, clsname, clsname) + +namespace detail { + +TORCH_API const at::HPUHooksInterface& getHPUHooks(); + +} // namespace detail +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/IPUHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/IPUHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..69af0b12c97ecef00cb6bdd12413f209afb86b0e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/IPUHooksInterface.h @@ -0,0 +1,43 @@ +#pragma 
once + +#include + +#include +#include +#include + +namespace at { + +struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface { + ~IPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + return false; + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } + + Generator getNewGenerator( + DeviceIndex /*device_index*/ = -1) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } +}; + +struct TORCH_API IPUHooksArgs {}; + +TORCH_DECLARE_REGISTRY(IPUHooksRegistry, IPUHooksInterface, IPUHooksArgs); +#define REGISTER_IPU_HOOKS(clsname) \ + C10_REGISTER_CLASS(IPUHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const IPUHooksInterface& getIPUHooks(); +} // namespace detail +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/MAIAHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/MAIAHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..ff3c338cc02e8db3401d78a43c875c594b24828a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/MAIAHooksInterface.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +struct TORCH_API MAIAHooksInterface : AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + ~MAIAHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/) const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + return false; + } + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed MAIA version information."); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." in a variadic macro" +struct TORCH_API MAIAHooksArgs {}; + +TORCH_DECLARE_REGISTRY(MAIAHooksRegistry, MAIAHooksInterface, MAIAHooksArgs); +#define REGISTER_MAIA_HOOKS(clsname) \ + C10_REGISTER_CLASS(MAIAHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MAIAHooksInterface& getMAIAHooks(); +} // namespace detail + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/MPSHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/MPSHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..c33977e639233117922e471579ff7c66d6c29983 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/MPSHooksInterface.h @@ -0,0 +1,125 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include + +#include +#include +#include + +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") +namespace at { + +struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { + // this fails the implementation if MPSHooks functions are called, but + // MPS backend is not present. 
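+  // Illustrative sketch (hypothetical call site, not part of this header):
+  // on a build without the MPS backend, a call such as
+  //   at::detail::getMPSHooks().deviceSynchronize();
+  // trips the macro below and raises
+  //   "Cannot execute deviceSynchronize() without MPS backend."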
+ #define FAIL_MPSHOOKS_FUNC(func) \ + TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend."); + + ~MPSHooksInterface() override = default; + + // Initialize the MPS library state + void init() const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual bool hasMPS() const { + return false; + } + virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index) const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual Allocator* getMPSDeviceAllocator() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void deviceSynchronize() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void commitStream() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void* getCommandBuffer() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void* getDispatchQueue() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void emptyCache() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual size_t getCurrentAllocatedMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual size_t getDriverAllocatedMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual size_t getRecommendedMaxMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void setMemoryFraction(double /*ratio*/) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void profilerStopTrace() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual uint32_t acquireEvent(bool enable_timing) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + Device getDeviceFromPtr(void* data) const override { + TORCH_CHECK(false, "Cannot get device of pointer on MPS without ATen_mps library. 
"); + } + virtual void releaseEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void recordEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void waitForEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void synchronizeEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual bool queryEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + bool hasPrimaryContext(DeviceIndex device_index) const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + bool isPinnedPtr(const void* data) const override { + return false; + } + Allocator* getPinnedMemoryAllocator() const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + #undef FAIL_MPSHOOKS_FUNC +}; + +struct TORCH_API MPSHooksArgs {}; + +TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs); +#define REGISTER_MPS_HOOKS(clsname) \ + C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MPSHooksInterface& getMPSHooks(); + +} // namespace detail +} // namespace at +C10_DIAGNOSTIC_POP() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..4f034cc4040cfed7337669613d271fc42f6dcd18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h @@ -0,0 +1,164 @@ +#pragma once + +#include +#include + +#include +#include + +#include + +#include +#include + +#include +namespace at { +class Context; +} + +namespace at { +constexpr const char* MTIA_HELP = + "The MTIA backend requires MTIA extension for PyTorch;" + "this error has occurred because you are trying " + "to use some MTIA's functionality without MTIA extension included."; + +struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { +// this fails the implementation if MTIAHooks functions are called, but +// MTIA backend is not present. +#define FAIL_MTIAHOOKS_FUNC(func) \ + TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend."); + + ~MTIAHooksInterface() override = default; + + void init() const override { + // Avoid logging here, since MTIA needs init devices first then it will know + // how many devices are available. Make it as no-op if mtia extension is not + // dynamically loaded. 
+ return; + } + + virtual bool hasMTIA() const { + return false; + } + + DeviceIndex deviceCount() const override { + return 0; + } + + virtual void deviceSynchronize(c10::DeviceIndex /*device_index*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual std::string showConfig() const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/) const override { + return false; + } + + void setCurrentDevice(DeviceIndex /*device*/) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + DeviceIndex getCurrentDevice() const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + DeviceIndex exchangeDevice(DeviceIndex /*device*/) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + DeviceIndex maybeExchangeDevice(DeviceIndex /*device*/) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual c10::Stream getCurrentStream(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); + } + + virtual int64_t getCurrentRawStream(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual c10::Stream getDefaultStream(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); + } + + virtual void setCurrentStream(const c10::Stream& /*stream*/ ) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + bool isPinnedPtr(const void* /*data*/) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual PyObject* memoryStats(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual PyObject* getDeviceCapability(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual PyObject* getDeviceProperties(DeviceIndex device) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual void emptyCache() const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + + virtual void recordMemoryHistory( + const std::optional& /*enabled*/, + const std::string& /*stacks*/, + size_t /*max_entries*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual PyObject* memorySnapshot(const std::optional& local_path) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual DeviceIndex getDeviceCount() const { + FAIL_MTIAHOOKS_FUNC(__func__); + return 0; + } + + virtual void resetPeakMemoryStats(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual void attachOutOfMemoryObserver(PyObject* observer) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return; + } +}; + +struct TORCH_API MTIAHooksArgs {}; + +TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs); +#define REGISTER_MTIA_HOOKS(clsname) \ + C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MTIAHooksInterface& getMTIAHooks(); +TORCH_API bool isMTIAHooksBuilt(); +} // namespace detail +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..1555d147395f7c8500788b866200bda1784614aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h @@ -0,0 +1,89 @@ +#pragma once + +#include +#include + 
+#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") + +namespace at { + +struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { +#define FAIL_PRIVATEUSE1HOOKS_FUNC(func) \ + TORCH_CHECK_NOT_IMPLEMENTED( \ + false, \ + "You should register `PrivateUse1HooksInterface`", \ + "by `RegisterPrivateUse1HooksInterface` and implement `", \ + func, \ + "` at the same time for PrivateUse1."); + + ~PrivateUse1HooksInterface() override = default; + + bool isBuilt() const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + bool isAvailable() const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + // TODO(FFFrog): Perserved for BC and will be removed in the future. + if (at::GetGeneratorPrivate().has_value()) + return at::GetGeneratorForPrivateuse1(device_index); + + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + at::Device getDeviceFromPtr(void* data) const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + bool isPinnedPtr(const void* data) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + void init() const override {} + virtual void resizePrivateUse1Bytes( + const c10::Storage& storage, + size_t newsize) const { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + +#undef FAIL_PRIVATEUSE1HOOKS_FUNC +}; + +struct TORCH_API PrivateUse1HooksArgs {}; + +TORCH_API void RegisterPrivateUse1HooksInterface( + at::PrivateUse1HooksInterface* hook_); + +TORCH_API bool isPrivateUse1HooksRegistered(); + +namespace detail { + +TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks(); + +} // namespace detail + +} // namespace at + +C10_DIAGNOSTIC_POP() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..53b876455328a21531313a21149f05c8bbb0bfa9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include + +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") + +namespace at { + +struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ + ~XPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize XPU without ATen_xpu library."); + } + + virtual bool hasXPU() const { + return false; + } + + virtual std::string showConfig() const { + TORCH_CHECK( + false, + "Cannot query detailed XPU version without ATen_xpu library."); + } + + virtual int32_t getGlobalIdxFromDevice(const Device& device) const { + TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library."); + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK( + false, "Cannot get default XPU generator without ATen_xpu library."); + } + + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK(false, "Cannot get XPU generator without 
ATen_xpu library."); + } + + virtual DeviceIndex getNumGPUs() const { + return 0; + } + + virtual DeviceIndex current_device() const { + TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library."); + } + + Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library."); + } + + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library."); + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Cannot get XPU pinned memory allocator without ATen_xpu library."); + } + + bool isPinnedPtr(const void* data) const override { + return false; + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot query primary context without ATen_xpu library."); + } +}; + +struct TORCH_API XPUHooksArgs {}; + +TORCH_DECLARE_REGISTRY(XPUHooksRegistry, XPUHooksInterface, XPUHooksArgs); +#define REGISTER_XPU_HOOKS(clsname) \ + C10_REGISTER_CLASS(XPUHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const XPUHooksInterface& getXPUHooks(); +} // namespace detail +} // namespace at +C10_DIAGNOSTIC_POP() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h new file mode 100644 index 0000000000000000000000000000000000000000..fc151cedd1b05936922a94df29d771f39884e749 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::functorch { + +// These are the interpreters for our AD transforms +// (grad, vjp and jvp). +// See NOTE: [functorch interpreter stack] for more details. 
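+// For orientation, a minimal usage sketch (hypothetical caller; `interpreter`
+// and `tensor` are placeholder names, not part of this header): given an
+// Interpreter whose key() is TransformType::Grad,
+//   GradInterpreterPtr ptr(&interpreter);
+//   bool was_grad_enabled = ptr.prevGradMode();
+//   Tensor wrapped = ptr.lift(tensor);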
+ +struct TORCH_API GradInterpreterPtr { + explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + bool prevGradMode() const { + return std::get(base_->meta()).prevGradMode_; + } + Tensor lift(const Tensor& tensor) const; + private: + const Interpreter* base_; +}; + +struct TORCH_API JvpInterpreterPtr { + explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + bool prevFwdGradMode() const { + return std::get(base_->meta()).prevFwdGradMode_; + } + Tensor lift(const Tensor& tensor) const; + private: + const Interpreter* base_; +}; + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..f7db9b158d6fa7386987f54949ec9a31cec5595e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h @@ -0,0 +1,481 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +#pragma once + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// This file contains helper functions for batching rules. 
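+// For orientation: a batching rule maps (Tensor, optional batch dim) pairs to
+// (Tensor, optional batch dim) pairs. A hand-written rule for a unary op has
+// roughly this shape (illustrative sketch only, not part of this header):
+//   std::tuple<at::Tensor, std::optional<int64_t>> sin_batch_rule(
+//       const at::Tensor& self, std::optional<int64_t> self_bdim) {
+//     return std::make_tuple(at::sin(self), self_bdim);
+//   }
+// The helpers and macros below generate, register, or otherwise support rules
+// of this form for the vmap transform.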
+ +namespace at::functorch { + +TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x); +TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x); + +TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x); + +Tensor moveBatchDimToFront(Tensor tensor, std::optional maybe_batch_dim); +int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional maybe_batch_dim); +int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional maybe_batch_dim); +std::optional valIfNonempty(std::optional maybe_empty, int64_t new_val); +int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim); +VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims); + +void vmapIncompatibleInplaceError(const char* schema_name); + +Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional has_bdim, int64_t logical_rank); + +void check_randomness(RandomnessType randomness); +void check_randomness(RandomnessType randomness, bool any_tensor_bdim); + +inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) { + if (has_bdim) { + return tensor; + } + const auto sizes = tensor.sym_sizes(); + SymDimVector expanded_shape; + expanded_shape.reserve(sizes.size()); + expanded_shape.emplace_back(std::move(batch_size)); + expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end()); + return tensor.expand_symint(expanded_shape); +} + +#define VMAP_SUPPORT(op, batch_rule) \ + m.impl(#op, op ## _generated_plumbing); + +#define VMAP_SUPPORT2(op, overload, batch_rule) \ + m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing); + +#define OP_DECOMPOSE(op) m.impl(#op, static_cast(native::op)); +#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast(native::op)); + +// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain +template +struct BasicUnaryBatchRuleHelper; + +template +struct BasicUnaryBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, + std::optional batch_dim, + T... extra_args) { + return std::make_tuple(Func(tensor, std::forward(extra_args)...), batch_dim); + } +}; + +// USAGE: BASIC_UNARY_BATCH_RULE(at::sin) +// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin) +// It is important that this macro is not passed a function pointer!! +#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\ + BasicUnaryBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +#define UNARY_POINTWISE(op) \ + VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op))); + +template +struct VariadicBdimsBatchRuleHelper; + +template +struct VariadicBdimsBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, + std::optional batch_dim, + T... extra_args) { + auto tensor_ = moveBatchDimToFront(tensor, batch_dim); + return std::make_tuple(Func(tensor_, std::forward(extra_args)...), 0); + } +}; + +// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse) +// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse) +// It is important that this macro is not passed a function pointer!! 
+#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\ + VariadicBdimsBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +#define VARIADIC_BDIMS(op) \ + VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op))); + +#define VARIADIC_BDIMS2(op, overload) \ + VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload))); + +template +void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule"); + + int64_t cur_level = maybe_layer->layerId(); + + auto orig_arguments = torch::jit::last(*stack, num_arguments); + if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + auto arguments = torch::jit::pop(*stack, num_arguments); + std::vector>> tensor_inputs; + std::vector tensor_pos; + for (const auto idx : c10::irange(0, num_arguments)) { + const auto& ivalue = arguments[idx]; + if (ivalue.isTensor()) { + auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level); + tensor_inputs.emplace_back(std::move(tensor_value), tensor_bdim); + tensor_pos.push_back(static_cast(idx)); + } + } + Func(tensor_inputs); + + size_t tensor_idx = 0; + TORCH_INTERNAL_ASSERT(!tensor_pos.empty()); + for (const auto arg_idx : c10::irange(0, num_arguments)) { + if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) { + torch::jit::push(stack, arguments[arg_idx]); + } else { + TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size()); + torch::jit::push(stack, tensor_inputs[tensor_idx].first); + tensor_idx++; + } + } + + op.callBoxed(stack); + const auto returns = torch::jit::pop(*stack, num_returns); + for (const auto& ret : returns) { + if (ret.isTensor()) { + torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level)); + } else { + TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values"); + } + } +} + +inline void handle_pointwise_ops(std::vector>> &tensor_inputs) { + int64_t out_logical_rank = 0; + for (auto& tensor_input : tensor_inputs) { + int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second); + out_logical_rank = std::max(out_logical_rank, cur_logical_rank); + } + for (auto& tensor_input: tensor_inputs) { + tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second); + tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank); + } +} + +#define POINTWISE_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +#define POINTWISE_BOXED2(op, overload) \ + m.impl(#op "." 
#overload, torch::CppFunction::makeFromBoxedFunction>()); + +inline void handle_variadic_bdims(std::vector>> &tensor_inputs) { + for (auto & tensor_input : tensor_inputs) { + tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second); + } +} + +#define VARIADIC_BDIMS_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +using UnpackedBatchedTensor = std::tuple>; + +inline void find_and_unpack_tensors( + const torch::jit::Stack* stack, + int64_t num_args, + int64_t cur_level, + SmallVector* tensors, + SmallVector* tensors_pos, + int64_t* batch_size) { + + int64_t computed_batch_size = -1; + int64_t args_begin = static_cast(stack->size()) - num_args; + + for (const auto idx : c10::irange(0, num_args)) { + const auto& ivalue = (*stack)[args_begin + idx]; + if (!ivalue.isTensor()) { + continue; + } + auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level); + const auto& [tensor_value, tensor_bdim] = unpacked; + if (tensor_bdim.has_value()) { + auto candidate_batch_size = tensor_value.size(*tensor_bdim); + if (computed_batch_size == -1) { + computed_batch_size = candidate_batch_size; + } + TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size); + } + + tensors->push_back(std::move(unpacked)); + tensors_pos->push_back(idx); + } + TORCH_INTERNAL_ASSERT(computed_batch_size > -1); + *batch_size = computed_batch_size; +} + +inline void boxed_existing_bdim_all_batch_rule( + const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = static_cast(schema.arguments().size()); + + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + const auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule"); + + const auto arguments = torch::jit::last(stack, num_arguments); + if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + int64_t args_begin = static_cast(stack->size()) - num_arguments; + SmallVector tensor_inputs; + SmallVector tensor_pos; + int64_t batch_size = 0; + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + int64_t cur_level = maybe_layer->layerId(); + + find_and_unpack_tensors( + stack, num_arguments, cur_level, + &tensor_inputs, &tensor_pos, &batch_size); + + // for each tensor, ensure it has a bdim and reshape it. + for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) { + const auto& [value, bdim] = tensor_inputs[tensor_idx]; + auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size); + (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(bdim.value_or(0), 0, value_); + } + + op.callBoxed(stack); + + for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) { + const auto& ret = (*stack)[idx]; + TORCH_INTERNAL_ASSERT(ret.isTensor(), + "This boxed batching rule does not currently support ops that return non-tensor values"); + (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level); + } +} + +// Use when all tensors arguments accept one (normal) batch dim. +// This batching rule expands the batch dim on all Tensors, reshapes it into +// dim 0, calls the op, and then reshapes the batch dim out of dim 0. +// This is not the most efficient thing; if there are alternatives, plese try +// to use them. Use this only as a last resort. 
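+// USAGE sketch (hypothetical op name): inside a registration block such as
+// TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m), one would write
+//   EXISTING_BDIM_ALL_BOXED(my_op);
+// which registers boxed_existing_bdim_all_batch_rule as the boxed batched
+// kernel for "my_op".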
+#define EXISTING_BDIM_ALL_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction()); + +template +inline void boxed_all_tensors_have_optional_bdim( + const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim"); + int64_t cur_level = maybe_layer->layerId(); + + const auto arguments = torch::jit::last(stack, num_arguments); + if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + int64_t args_begin = static_cast(stack->size() - num_arguments); + SmallVector tensor_inputs; + SmallVector tensor_pos; + int64_t batch_size = 0; + + find_and_unpack_tensors( + stack, static_cast(num_arguments), cur_level, + &tensor_inputs, &tensor_pos, &batch_size); + + std::optional is_no_batch_dim_case; + + for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) { + const auto& value = std::get<0>(tensor_inputs[tensor_idx]); + auto bdim = std::get<1>(tensor_inputs[tensor_idx]); + const auto logical_rank = rankWithoutBatchDim(value, bdim); + + if (!is_no_batch_dim_case.has_value()) { + is_no_batch_dim_case = (logical_rank == feature_rank); + } + auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size); + if (!bdim.has_value()) { + bdim = 0; + } + if (*is_no_batch_dim_case) { + TORCH_INTERNAL_ASSERT(logical_rank == feature_rank); + value_ = moveBatchDimToFront(value_, bdim); + if (tensor_idx == contig_tensor_index) { + value_ = value_.contiguous(); + } + (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_); + continue; + } + TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1); + value_ = reshape_dim_into(*bdim, 0, value_); + if (tensor_idx == contig_tensor_index) { + value_ = value_.contiguous(); + } + (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_); + } + + op.callBoxed(stack); + + for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) { + const auto& ret = (*stack)[idx]; + TORCH_INTERNAL_ASSERT(ret.isTensor(), + "This boxed batching rule does not currently support ops that return non-tensor values"); + if (*is_no_batch_dim_case) { + (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level); + } else { + (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level); + } + } +} + +// Useful for many NN operators. +// The operator must satisfy the following: +// - All arguments must accept an optional batch dim. +// - All arguments must be the same rank +#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \ + m.impl(#op, \ + torch::CppFunction::makeFromBoxedFunction<\ + boxed_all_tensors_have_optional_bdim<\ + feature_rank, \ + contig_tensor_index>\ + >()); + +template +struct ExistingBdimBatchRuleHelper; + +template +struct ExistingBdimBatchRuleHelper> { + static std::tuple> apply( + const Tensor& self, + std::optional self_bdim, + T... 
extra_args) { + auto self_ = reshape_dim_into(*self_bdim, 0, self); + auto out = Func(self_, std::forward(extra_args)...); + return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0); + } +}; + +// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse) +// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse) +// It is important that this macro is not passed a function pointer!! +#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\ + ExistingBdimBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + + +#define EXISTING_BDIM(op) \ + VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op))); + +#define EXISTING_BDIM2(op, overload) \ + VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload))); + +#define INVOKE(object,ptrToMember) ((object).*(ptrToMember)) + + +template +Tensor& unary_inplace_batch_rule(Tensor& self, std::optional, ExtraArgs... extra_args) { + INVOKE(self, Method)(std::forward(extra_args)...); + return self; +} + +inline int64_t get_bdim_size4( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim, + const Tensor& c_value, std::optional c_bdim, + const Tensor& d_value, std::optional d_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + if (c_bdim) + return c_value.size(*c_bdim); + if (d_bdim) + return d_value.size(*d_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline int64_t get_bdim_size3( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim, + const Tensor& c_value, std::optional c_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + if (c_bdim) + return c_value.size(*c_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline int64_t get_bdim_size2( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline c10::SymInt get_bdim_size2_symint( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim) { + if (a_bdim) + return a_value.sym_size(*a_bdim); + if (b_bdim) + return b_value.sym_size(*b_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +// [start, start + 1, ..., stop - 1] +inline VmapDimVector range(int64_t start, int64_t stop) { + TORCH_INTERNAL_ASSERT(stop >= start); + VmapDimVector dims; + dims.reserve(stop - start); + for (int64_t i = start; i < stop; i++) { + dims.emplace_back(i); + } + return dims; +} +std::tuple _binary_pointwise_helper( + const Tensor& tensor, std::optional tensor_batch_dim, const Tensor& other, std::optional other_batch_dim, + bool do_type_promotion=true); + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h new file mode 100644 index 0000000000000000000000000000000000000000..ab4fbc662aa3e0f28bc4e15432e56377a471a196 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h @@ -0,0 +1,81 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once +#include +#include +#include + +namespace at::functorch { + +// This file contains code for the vmap fallback (also known as the +// BatchedTensor fallback or the Batched fallback). This code runs +// when an operation doesn't have a batching rule implemented. + +// If an operator doesn't have a batching rule implemented then we fallback +// to this implementation. The fallback doesn't work on out= variants or +// view operations; that is, it works for out-of-place operations and +// in-place non-view operations. +// +// For out-of-place operations, the fallback effectively takes all of the +// BatchedTensors in `stack`, slices them, and runs `op` on all of the +// corresponding slices to produce slices of the outputs. The output slices +// then get `torch.stack`ed to create the +// final returns. +// +// The performance of the fallback is not very good because it introduces an +// extra copy from stacking the sliced outputs. Because of this, we prefer to +// write batching rules for operators whenever possible. +void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); +void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +// The vmap fallback emits a warning by default, but it may be disabled if +// the user finds it to be too annoying. +TORCH_API bool isVmapFallbackWarningEnabled(); +TORCH_API void setVmapFallbackWarningEnabled(bool enabled); + +// Used for testing. The vmap fallback is enabled by default. When it is disabled, +// it raises an error. +TORCH_API bool isVmapFallbackEnabled(); +TORCH_API void setVmapFallbackEnabled(bool enabled); + +template A vector_to_result(const std::vector& buffer) { + return buffer[0].to(); +} +template std::tuple vector_to_result(const std::vector& buffer) { + return std::make_tuple(buffer[0].to(), buffer[1].to()); +} +template std::tuple vector_to_result(const std::vector& buffer) { + return std::make_tuple(buffer[0].to(), buffer[1].to(), buffer[2].to()); +} + +// slow_fallback is a way to call the vmap fallback inside some boxed kernel. +// There is probably some better way to metaprogram this. +template +Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + +template +std::tuple slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + +template +std::tuple slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..c926dffa7ccf712401b67c0c7fb0533c9ddac87a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h @@ -0,0 +1,170 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include + +#include +#include +#include + +namespace at::functorch { + +using Tensor = at::Tensor; + +// We assume this in a few other places in the codebase, +// but there isn't a centralized definition. +constexpr int64_t kVmapMaxTensorDims = 64; + +// The valid vmap levels range from [0, 64). This effectively means that we +// support a maximum of 64 nested vmaps. +constexpr int64_t kVmapNumLevels = 64; + +// Store this number of elements of BatchDims on the stack. Most people will +// probably use <= 5 nested vmaps, but adjust this number as necessary. +constexpr int64_t kBatchDimsStackSize = 5; + +// A BatchedTensorImpl holds an underlying Tensor and a single batch dim +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +// +// The batch dimensions are treated as being "private"; they are not user-visible. +// For example, in the following Tensor, +// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0) +// dimension 0 is batch dimension. +// +// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public) +// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor. +struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { + explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level); + + // Returns batch dimension of this tensor + int64_t bdim() const { return bdim_; } + + // Returns batch dimension of this tensor + int64_t level() const { return level_; } + + // BatchedTensorImpl wraps a Tensor + const Tensor& value() const { return value_; } + + // Given a public dimension index, return the dimension index in the underlying + // value() tensor. + // For example, if we have + // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0) + // bt.actualDim(0) -> 1 + // bt.actualDim(1) -> 2 + // bt.actualDim(2) -> 3 + // bt.actualDim(3) -> Error + int64_t actualDim(int64_t dim, bool wrap_dim = true) const; + + IntArrayRef sizes_custom() const override; + SymIntArrayRef sym_sizes_custom() const override; + int64_t size_custom(int64_t d) const override; + c10::SymInt sym_size_custom(int64_t d) const override; + // We have to override this because we opted into CustomStrides + IntArrayRef strides_custom() const override; + SymIntArrayRef sym_strides_custom() const override; + // Override a bunch of methods inherited from TensorImpl to return error messages. + bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; +#ifdef DEBUG + bool has_storage() const override; +#endif + + void refreshTensorMetadata(); + + // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it + // accomplishes this is a hack where it is able to modify the levels of + // BatchedTensor to match the level of the current vmap transform. 
+ void _unsafe_set_level(int64_t level) { + level_ = level; + } + + // Used in batching rule for in-place view operations that can change + // the index of the bdim (think squeeze_, unsqueeze_) + void unsafe_set_bdim(int64_t bdim) { + // NB: you MUST call refreshTensorMetadata after doing this. + bdim_ = bdim; + } + private: + // see NOTE: [BatchedTensorImpl levels invariant] + void checkInvariants() const; + const char* tensorimpl_type_name() const override; + + Tensor value_; + + int64_t level_; + int64_t bdim_; +}; + +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +inline bool isBatchedTensor(const Tensor& tensor) { + return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) || + tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor); +} + +// It is unsafe to call this on a Tensor that is not backed by a +// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible. +inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) { + return static_cast(tensor.unsafeGetTensorImpl()); +} + +inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) { + if (!isBatchedTensor(tensor)) { + return nullptr; + } + return unsafeGetBatchedImpl(tensor); +} + +// Returns a bitset. If bit i is set, then that means dim i is a batchdim. +inline std::bitset createBatchDimBitset(int64_t dim) { + std::bitset is_bdim; + is_bdim.set(dim); + return is_bdim; +} + +// Creates a bitset for the given level +inline std::bitset createVmapLevelsBitset(int64_t level) { + std::bitset result; + result.set(level); + return result; +} + +// Use this to construct a BatchedTensor from a regular Tensor +TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level); + +// Adds a batch dim to `tensor`, returning a BatchedTensor +TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level); + +// Certain dispatch keys must be propagated to the BatchedTensor (or, in general, +// any wrapper Tensor subclasses). This is because there are methods on Tensor +// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()). +// TODO: should probably contain more (or all?) backend keys +constexpr DispatchKeySet kKeysToPropagateToWrapper({ + DispatchKey::Negative, + DispatchKey::Conjugate, + DispatchKey::XLA, + DispatchKey::CUDA, + DispatchKey::CPU, + DispatchKey::PrivateUse1, +}); + +inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { + auto key_set = tensor.unsafeGetTensorImpl()->key_set(); + return key_set & kKeysToPropagateToWrapper; +} + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h new file mode 100644 index 0000000000000000000000000000000000000000..9f58df839d442afb9f979841770edc8f80427582 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h @@ -0,0 +1,126 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +// This file contains template metaprogramming things that are used for our +// batching rules. +// +// See NOTE: [vmap plumbing] for more details on why this is necessary. 
+// The plumbing has a bunch of metaprogramming hacks for determining the signature +// of a batching rule from the signature of the operator, many of which use the +// helper functions in this file. + +namespace at::functorch { + +// Metaprogramming things +template using typelist = c10::guts::typelist::typelist; +template using head_t = c10::guts::typelist::head_t; +template using concat_t = c10::guts::typelist::concat_t; +template class debug_t; + +// tail operation +template +struct tail final { + static_assert(c10::guts::false_t::value, + "In typelist::tail, the T argument must be typelist<...>."); +}; +template +struct tail> final { + using type = typelist; +}; +template using tail_t = typename tail::type; + +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext { + using type = Next; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, std::optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, std::optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, std::optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, std::optional, Next, Tail> { + using type = Tail; +}; +template struct RemoveBatchDimAfterTensor { + using first = head_t; + using next = tail_t; + using second = head_t; + using tail = tail_t; + + using type = concat_t< + typelist, + typename RemoveBatchDimAfterTensor< + typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext::type + >::type + >; +}; +template struct RemoveBatchDimAfterTensor> { + using type = typelist; +}; +template <> struct RemoveBatchDimAfterTensor> { + using type = typelist<>; +}; +template using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor::type; + +template struct UnpackSingleItemTuple { + using type = T; +}; +template struct UnpackSingleItemTuple> { + using type = T; +}; +template using unpack_single_item_tuple_t = typename UnpackSingleItemTuple::type; + +template struct BuildFunctionHelper; +template struct BuildFunctionHelper> { + using type = Return(Args...); +}; +template +struct BuildFunction { + using type = typename BuildFunctionHelper>::type; +}; +template using build_function_t = typename BuildFunction::type; + + +template struct ToOperatorType { + using batch_rule_return_type = typename c10::guts::function_traits::return_type; + using batch_rule_parameter_types = typename c10::guts::function_traits::parameter_types; + + using operator_parameter_types = remove_batch_dim_after_tensor_t; + using operator_return_type = + unpack_single_item_tuple_t< + c10::guts::typelist::to_tuple_t< + remove_batch_dim_after_tensor_t< + c10::guts::typelist::from_tuple_t>>>; + + using type = build_function_t; +}; +template using to_operator_t = typename ToOperatorType::type; + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..da21067e4d0ca94d45cd29a3a475e69602f72f36 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h @@ -0,0 +1,124 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declared +namespace c10 { struct AutogradMetaInterface; } + +namespace at::functorch { + +// This file contains the implementation of functorch's interpreter stack. +// See NOTE: [functorch interpreter stack] first before reading on. +// +// NB: the functorch interpreter stack is also referred to as: +// - the "dynamic layer stack" -- an older name for "interpreter" was +// "dynamic layer". +// - the "functorch mode stack". You can think of each functorch transform as a +// "mode" (in the same sense as torch_dispatch mode or torch_function mode), +// and functorch being an implementation of a "mode stack" where the modes +// may be arbitrary composed. + +// DynamicLayer is basically the same thing as an Interpreter. +// It represents a functorch transform and it holds an Interpreter, +// which contains metadata related to the transform and instructions on +// how to perform the transform. +// +// TODO: we can excise DynamicLayer in favor of Interpreter, +// But I am going to leave it for now as a compatiblity shim to avoid +// needing to refactor a lot of callsites... +struct TORCH_API DynamicLayer { + explicit DynamicLayer( + TransformType transform_type, + int64_t layerId, + std::optional batchSize = std::nullopt, + std::optional randomness = std::nullopt, + std::optional prev_grad_mode = std::nullopt, + std::optional pre_fwd_grad_mode = std::nullopt, + std::optional functionalize_add_back_views = std::nullopt); + + TransformType key() const; + int64_t layerId() const; + + const Interpreter& interpreter() const { return interpreter_; } + Interpreter& interpreter() { return interpreter_; } + + // Only valid for vmap + c10::SymInt batchSize() const; + RandomnessType randomness() const; + + private: + Interpreter interpreter_; +}; + +TORCH_API int64_t initAndPushDynamicLayer( + TransformType transform_type, + std::optional batch_size = std::nullopt, + std::optional randomness = std::nullopt, + std::optional prev_grad_mode = std::nullopt, + std::optional prev_fwd_grad_mode = std::nullopt, + std::optional functionalize_add_back_views = std::nullopt); +TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata(); +TORCH_API std::optional maybeCurrentDynamicLayer(); +TORCH_API const std::vector& getDynamicLayerStack(); +TORCH_API void setDynamicLayerStack(const std::vector& stack); +TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included); + +// NOTE: [Life handles and lexically scoped transforms] +// functorch transforms are lexically scoped. +// Given a level, we store a "life handle" that is a boolean that tells us if the +// transform with that level is active or not. +// +// functorch's TensorWrapper (for grad transforms) stores a life handle. +// If a TensorWrapper escapes from the scope of the transform, then somehow +// it must know it escaped; it can tell by querying the life handle. +TORCH_API const std::shared_ptr& getLifeHandleForLevel(int64_t level); + +// Returns if an operator is in-place. An operator is inplace if: +// 1. The first argument is a Tensor and it is being written to +// 2. The first argument is being returned +// 3. 
No other arguments are aliased +// Here is an example of an in-place operator: +// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema); + +// Given the indices of unwrapped inputs and the schema, this returns the indices of any outputs that should remain unwrapped +TORCH_API std::optional findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input); + +TORCH_API Tensor unwrapIfDead(const Tensor& tensor); +TORCH_API bool isDeadTensorWrapper(const Tensor& tensor); + +// Pretty printers +TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer); +TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector& dynamicLayerStack); + +// While a functorch transform is active, torch.autograd.function._SingleLevelFunction +// is disabled by default. The following two APIs are APIs for enabling +// it. These are not user-facing APIs. We can delete this in the future, but +// it is useful for debugging when something goes wrong with the +// autograd.Function <> functorch interaction, which uses _SingleLevelFunction, +// because it leads to loud errors if something is incorrect. +TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed); +TORCH_API bool getSingleLevelAutogradFunctionAllowed(); + +// While a functorch grad transform is active, Tensor.requires_grad_() gets +// disabled. These two functions are the mechanism to controlling that. +TORCH_API void setInplaceRequiresGradAllowed(bool allowed); +TORCH_API bool getInplaceRequiresGradAllowed(); + +TORCH_API DynamicLayer popDynamicLayer(); +TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer); + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..5ae0bcdccdf5fc5c061b542d175f677eace9a4c2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h @@ -0,0 +1,22 @@ +#pragma once +#include + +namespace at::functorch { + +// This is the interpreter that handles the functionalize() transform. +// See NOTE: [functorch interpreter stack] for more details. 
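+// For orientation, a minimal usage sketch (hypothetical caller; `interpreter`
+// is a placeholder name, not part of this header): given an Interpreter whose
+// key() is TransformType::Functionalize,
+//   FunctionalizeInterpreterPtr ptr(&interpreter);
+//   bool add_back_views = ptr.functionalizeAddBackViews();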
+ +struct FunctionalizeInterpreterPtr { + explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + bool functionalizeAddBackViews() const { + return std::get(base_->meta()).functionalizeAddBackViews_; + } + private: + const Interpreter* base_; +}; + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..32ef50c9feba977a1d1d55e3402b2827d9cb5571 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h @@ -0,0 +1,351 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::functorch { + +// NOTE: [functorch interpreter stack] +// +// functorch's dispatching system uses a stack of interpreters. +// Historically we've referred to this as the "DynamicLayerStack". +// +// An interpreter is something that reads in the code it is passed +// and then executes it. We have a different interpreter per-transform: +// the "VmapInterpreter" is responsible for reading in operators (like aten::mv) +// and executing the batched version of it (the batching rule for aten::mv). +// +// Concretely, each interpreter is responsible for two things: +// +// 1) process(ophandle, stack) +// Given an operator handle and a stack of arguments, the interpreter is +// responsible for figuring out how to execute the operation under the semantics +// of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call +// the batching rule. +// +// The batching rules are stored as kernels on the FuncTorchBatched key, so the way +// VmapInterpreter calls the batching rule is roughly: (A) exclude all +// dispatch keys aside from the Batched key, (B) redispatch so we get to the +// Batched key. +// +// 2) sendToNextInterpreter(ophandle, stack) +// The VmapInterpreter, when it sees aten::mv, will process it into a call to +// aten::mm. It then needs to send the call to aten::mm to the next interpreter +// in the interpreter stack. +// +// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack) +// and most Interpreters will implement it this way. + +enum class RandomnessType { + Error, // always errors when calling a random function + Same, // randomness appears the same across batches + Different, // randomness appears different across batches + END +}; + +enum class TransformType { + Torch, // Unused + Vmap, + Grad, // reverse-mode AD, aka vjp + Jvp, // forward-mode AD + Functionalize, +}; + +std::ostream& operator<<(std::ostream& os, const TransformType& t); + +// NOTE: [Interpreter "subclassing" design] +// +// How are various Interpreters for different transforms (vmap, grad, ...) +// implemented? +// +// Accessing interpreters is in the hot-path of functorch so we have a constraint +// that this code must be as fast as possible. +// +// As a result, we stay away from virtual methods and this causes our code +// to look a little funny. +// +// `Interpreter` is the struct for Interpreters. 
It holds ALL of the +// relevant information (what type of interpreter it is and the metadata). +// Metadata for each interpreter is represented as a Union (std::variant) +// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...). +// +// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this +// if you want to access the metadata fields (like batchSize and randomness). +// +// Each type of interpreter (e.g. Vmap) has a convenience struct +// (e.g. VmapInterpreterPtr) associated with it. +// +// Construct the convenience struct with VmapInterpreterPtr(Interpreter*), +// and then one can access methods on VmapInterpreterPtr like so: +// >>> VmapInterpreterPtr(&interpreter).batchSize() +// +// Finally, Interpreter::process switches on the type of the interpreter +// and calls one of {Transform}Intepreter::processImpl under the hood. +// Same for Interpreter::sendToNextInterpreter :) + +struct VmapInterpreterMeta { + explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) : + batchSize_(std::move(batchSize)), randomness_(randomness) {} + + c10::SymInt batchSize_; + RandomnessType randomness_; + + VmapInterpreterMeta() = default; + VmapInterpreterMeta(const VmapInterpreterMeta&) = default; + VmapInterpreterMeta(VmapInterpreterMeta&&) = default; + VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default; + VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default; + ~VmapInterpreterMeta() = default; + + template + friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) { + if (json_t.batchSize_.is_heap_allocated()) { + throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet"); + } + json_j["batchSize"] = json_t.batchSize_.as_int_unchecked(); + json_j["randomness"] = static_cast(json_t.randomness_); + } + + template + friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) { + json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]); + json_t.randomness_ = static_cast(json_j["randomness"]); + } +}; + +struct GradInterpreterMeta { + explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {} + GradInterpreterMeta() = default; + GradInterpreterMeta(const GradInterpreterMeta&) = default; + GradInterpreterMeta(GradInterpreterMeta&&) = default; + GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default; + GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default; + ~GradInterpreterMeta() = default; + + bool prevGradMode_; + template + friend void to_json(T& json_j, const GradInterpreterMeta& json_t) { + json_j["prevGradMode"] = json_t.prevGradMode_; + } + + template + friend void from_json(const T& json_j, GradInterpreterMeta& json_t) { + json_t.prevGradMode_ = json_j["prevGradMode"]; + } +}; + +struct JvpInterpreterMeta { + explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {} + JvpInterpreterMeta() = default; + JvpInterpreterMeta(const JvpInterpreterMeta&) = default; + JvpInterpreterMeta(JvpInterpreterMeta&&) = default; + JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default; + JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default; + ~JvpInterpreterMeta() = default; + + bool prevFwdGradMode_; + template + friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) { + json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_; + } + + template + friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) { + json_t.prevFwdGradMode_ = 
json_j["prevFwdGradMode"]; + } +}; + +struct FunctionalizeInterpreterMeta { + explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) : + functionalizeAddBackViews_(functionalizeAddBackViews) {} + FunctionalizeInterpreterMeta() = default; + FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default; + FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default; + FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default; + FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default; + ~FunctionalizeInterpreterMeta() = default; + + bool functionalizeAddBackViews_; + template + friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) { + json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_; + } + + template + friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) { + json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"]; + } +}; + +typedef std::variant< + int64_t, + GradInterpreterMeta, + JvpInterpreterMeta, + VmapInterpreterMeta, + FunctionalizeInterpreterMeta +> InterpreterMeta; + + +struct Interpreter { + // factory functions + static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) { + return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness)); + } + static Interpreter Grad(int64_t level, bool prevGradMode) { + return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode)); + } + static Interpreter Jvp(int64_t level, bool prevFwdGradMode) { + return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode)); + } + static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) { + return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews)); + } + + // methods + TransformType key() const { return type_; } + int64_t level() const { return level_; } + const InterpreterMeta& meta() const { return meta_; } + + void process(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + + void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) { + TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value()); + savedLocalDispatchKeySet_ = keyset; + } + void clearSavedLocalDispatchKeySet() { + TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); + savedLocalDispatchKeySet_ = std::nullopt; + } + c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const { + TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); + return *savedLocalDispatchKeySet_; + } + + // An Interpreter is alive if we are currently inside the ongoing transform + // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's + // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack. 
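+  // A minimal construction sketch (the level and batch size are illustrative,
+  // not prescribed by this header; the real bookkeeping is driven by the
+  // DynamicLayerStack machinery):
+  //
+  //   auto interp = Interpreter::Vmap(/*level=*/1, c10::SymInt(4),
+  //                                   RandomnessType::Different);
+  //   interp.set_is_alive(true);   // marked alive while the transform runs
+  //   ...
+  //   interp.set_is_alive(false);  // and dead again once the transform exits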
+ bool is_alive() const { + return *is_alive_; + } + const std::shared_ptr& is_alive_ptr() const { + return is_alive_; + } + void set_is_alive(bool alive) { + *is_alive_ = alive; + } + + // Please don't use this + explicit Interpreter() = default; + + template + friend void to_json(T& json_j, const Interpreter& json_t) { + json_j["type"] = static_cast(json_t.type_); + json_j["level"] = json_t.level_; + if (json_t.savedLocalDispatchKeySet_) { + json_j["savedLocalDispatchKeySet"] = { + {"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()}, + {"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()} + }; + } else { + json_j["savedLocalDispatchKeySet"] = nlohmann::json(); + } + json_j["is_alive"] = *json_t.is_alive_; + std::visit([&](auto&& arg) { + using V = std::decay_t; + if constexpr (std::is_same_v) { + json_j["meta"] = {{"Torch", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Grad", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Jvp", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Vmap", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Functionalize", arg}}; + } else { + static_assert(false && sizeof(V), "unknown variant case"); + } + }, json_t.meta_); + } + + template + friend void from_json(const T& json_j, Interpreter& json_t) { + json_t.type_ = static_cast(json_j["type"]); + json_t.level_ = json_j["level"]; + auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"]; + if (savedLocalDispatchKeySet.is_null()) { + json_t.savedLocalDispatchKeySet_ = std::nullopt; + } else { + c10::impl::PODLocalDispatchKeySet pod; + pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get())); + pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get())); + json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod); + } + json_t.is_alive_ = std::make_shared(json_j["is_alive"]); + auto meta = json_j["meta"]; + if (meta.contains("Torch")) { + json_t.meta_.emplace(meta["Torch"].template get()); + } else if (meta.contains("Grad")) { + json_t.meta_.emplace(meta["Grad"].template get()); + } else if (meta.contains("Jvp")) { + json_t.meta_.emplace(meta["Jvp"].template get()); + } else if (meta.contains("Vmap")) { + json_t.meta_.emplace(meta["Vmap"].template get()); + } else if (meta.contains("Functionalize")) { + json_t.meta_.emplace(meta["Functionalize"].template get()); + } else { + throw std::runtime_error("unknown interpreter metadata type"); + } + } + + std::string serialize() const { + return nlohmann::json(*this).dump(); + } + + static Interpreter deserialize(const std::string& serialized) { + return nlohmann::json::parse(serialized).get(); + } + + private: + explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta): + type_(type), level_(level), is_alive_(std::make_shared(false)), meta_(std::move(meta)) {} + + // fields + TransformType type_{}; + int64_t level_{}; + std::optional savedLocalDispatchKeySet_; + std::shared_ptr is_alive_; + InterpreterMeta meta_; +}; + +// Applies the following for-loop: +// for i in range(begin, end): +// args[i] = func(args[i]) +void foreachTensorInplace(std::vector& args, int64_t begin, int64_t end, + std::function func); + +// Applies the following for-loop: +// for i in range(begin, end): +// if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset +// args[i] = func(args[i], i - begin, true) +// args[i] = func(args[i], i 
- begin) +void foreachTensorInplaceWithFlag(std::vector& args, int64_t begin, int64_t end, + const std::bitset<64> use_flag_relative, const std::function& func); + +std::vector findUnwrappedInputs(std::vector& args, int64_t begin, int64_t end); + +DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key); + +void setup_dispatch_key_tls(TransformType key, DispatchKeySet include); + +void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h new file mode 100644 index 0000000000000000000000000000000000000000..597eb9fd6ec033e4ec90c692eaa608efac95b1f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h @@ -0,0 +1,187 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace at::functorch { + +// This files contains the legacy (now-deprecated) batching rule API. +// Please try to use the new-style batching rule API (see writing_batch_rules.md) + +// This file contains abstractions used for transforming *logical* vmap arguments +// into *physical* arguments. (Keep reading for definitions of these terms). + +// NOTE: [Logical vs physical args] +// Consider the following vmap. +// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4)) +// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4], +// with batch dims 0 and 2: +// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)]) +// +// We say the *logical* view of the tensor has size [3] -- tensors inside +// `func` appear to have size [3]. +// However, the *physical* underlying tensor (the one passed to vmap) has size +// [2, 3, 4]. +// +// This notion of logical vs physical also extends to non-tensor arguments. +// Consider the previous tensor; let's assume the user called +// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical +// dimension they are reducing over is dim 0 but the physical dim is dim 1 +// (the first non-batch dimension) + +// Forward declared; see NOTE: [What is a VmapPhysicalView?] +struct VmapPhysicalView; + +// Most PyTorch operators take 4 or fewer inputs. +constexpr int64_t kVmapTransformStaticInputSize = 4; +using VmapPhysicalViewVec = SmallVector; + +// Pytorch generally advertises good performance for <= 5 dims. +// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap +// dimensions to get 8. Adjust this number as necessary +constexpr int64_t kVmapStaticDimVecSize = 8; +using VmapDimVector = SmallVector; +using VmapSymDimVector = SmallVector; + +// NOTE: [What is an VmapTransform?] +// An *VmapTransform* converts logical views of tensors to physical views. +// +// Batching rules use VmapTransforms to convert logical arguments to +// physical arguments, then call one or more at:: operator that handles the +// physical arguments, and then converts the physical result back to a logical +// argument. + +// VmapTransform for operators that take tensors with multiple batch dims. 
+// Given one or more logical views on Tensors, `logicalToPhysical` +// permutes all of the batch dims to the front of the tensor, aligns +// and expands the batch dims to match each other (according to their `level`), +// and returns a VmapPhysicalView on the tensor(s). +struct TORCH_API MultiBatchVmapTransform { + static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor); + static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors); +}; + +// VmapTransform for operators that broadcast all inputs. +// Given some logical views on Tensors, `logicalToPhysical`: +// - permutes all of the batch dims to the front of the tensors +// - aligns all the batch dims to the collective levels of all of the tensors. +// If a tensor does not have a batch dim for a vmap level, then it receives +// a size-one dimension for said level. +// - aligns the non-batch dims to have the same dimensionality, adding extra +// size-1 dimensions in between the batch dimensions and the non-batch dimensions +// so that the batch dimensions are lined up from the right. +// +// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch +// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors +// of size (B, 1, 2) and (B, 3, 2). +// +// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns +// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't +// actually *need* to return a tensor of size (1, 2) for the second tensor +// because the broadcasting operation takes care of that for us, but we do +// it anyways to keep things simple. +struct TORCH_API BroadcastingVmapTransform { + static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); +}; + +// Forward declared, if you're reading this file head to toe, don't worry about +// it yet. +struct VmapPhysicalToLogicalMap; + +// NOTE: [What is a VmapPhysicalView?] +// VmapPhysicalView represents a physical view on a Tensor. +// +// One can use it to further convert logical dimension indices, logical shapes, +// and more to their physical variants, or convert a new (physical) tensor into +// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented). +// +// VmapPhysicalView stores a physical tensor with all of its batch dimensions at +// the front and some levels that correspond to said batch dimensions. +// +// The levels bitset specifies which vmap levels correspond to the batch +// dimensions at the front of the tensor. In particular, the number of set bits +// corresponds to the number of batch dimensions on `tensor` and the rightmost +// bit of `levels` specifies the maximum number of nested vmaps we are in at +// this point in time. +// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. +// bitset: 010100 +// ^ +// | +// levels: 012345 +struct TORCH_API VmapPhysicalView { + VmapPhysicalView(Tensor&& tensor, std::bitset levels) + : levels_(levels), tensor_(std::move(tensor)) { + // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor)); + } + + Tensor& tensor() { return tensor_; } + const Tensor& tensor() const { return tensor_; } + + // Maps logical dim indices to physical dim indices. Also does dim wrapping. + // + // For example, given: + // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3}) + // + // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}. 
+ // This is because the size of levels tell us that the first two dimensions + // of `tensor_` are batch dimensions, so a logical dim of `n` is actually + // a physical dim of `n + 2`. + VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const; + int64_t getPhysicalDim(int64_t logical_dim) const; + + // Returns a VmapPhysicalToLogicalMap object. This can be used for + // mapping a physical tensor to a new logical tensor (BatchedTensor) + VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; + + // Maps a logical shape to a physical shape by pre-pending the batch + // sizes to the logical shape. + VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; + SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const; + + int64_t numBatchDims() const; + + private: + int64_t numLogicalDims() const; + + std::bitset levels_; + Tensor tensor_; +}; + +// Convenience struct used for mapping a physical tensor (a non-BatchedTensor) +// to a logical one (BatchedTensor). It holds some levels that are used to do the +// mapping and assumes that the batch dimensions in the physical tensor all +// occur at the front of the tensor. +struct TORCH_API VmapPhysicalToLogicalMap { + VmapPhysicalToLogicalMap(std::bitset levels): levels_(levels) {} + + // Maps a physical tensor to a new logical tensor (BatchedTensor). + // Assumes that all of the "batch dimensions" are at the front + // of the physical tensor. For example, given: + // - x = rank-4 Tensor with size 2, 3, 5, 7 + // - levels = (2, 4) + // Returns: + // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)]) + Tensor apply(const Tensor& physical_tensor) const; + + // Given a vector of physical tensors, + // 1. maps each tensor to a new logical tensor. Assumes that all of the + // "batch dimensions" are at the front of the physical tensors. + // 2. stores the new logical tensors back into the passed-in vector. This is + // to avoid additional dynamic allocations. + void applyInplace(std::vector& physical_tensors) const; + + std::bitset levels_; +}; + + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/Macros.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/Macros.h new file mode 100644 index 0000000000000000000000000000000000000000..b99be8781c127d5d8c49fdc1b7b80027c9383e48 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/Macros.h @@ -0,0 +1,3 @@ +#pragma once + +#define SINGLE_ARG(...) __VA_ARGS__ diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/PlumbingHelper.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/PlumbingHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..cd31ffb4553af64b7ab789a9e32d46f95bb8f18e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/PlumbingHelper.h @@ -0,0 +1,63 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +#pragma once +#include +#include +#include + +// NOTE: [vmap plumbing] +// +// Here's how "batching rules" work. +// - we register kernels to the Batched key +// - these kernels have the same signatures as the original operators. 
+// For example, at::sin(Tensor self) accepts a Tensor, and the batched kernel +// must also accept a Tensor +// - However, it is more natural for users to write a batching rule like the +// following: sin_batch_rule(Tensor self, std::optional self_bdim) +// - There is some codegenerated layer (the "plumbing") that wraps the user +// defined batching rule (e.g. sin_batch_rule) in a kernel that can be +// registered to the Batched key. +// +// The plumbing is responsible for wrapping a batching rule into a form that may +// be registered as the kernel for the batched key. + +namespace at::functorch { + +void vmap_check_escaped(const std::optional &layer, const char* what); + +// Create a BatchedTensor given a tensor, bdim, and level +TORCH_API Tensor makeBatched(Tensor tensor, std::optional bdim, int64_t level); + +// Given a Tensor that may or may not be a BatchedTensor, unwrap it. +// If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level +// doesn't match, then this returns (tensor, std::nullopt). +// Otherwise, it returns (unwrap(tensor), bdim). +TORCH_API std::tuple> unwrapTensorAtLevel(const Tensor& tensor, int64_t level); + +// Creates a vector of BatchedTensor +TORCH_API std::vector makeBatchedVector(std::vector tensors, std::optional bdim, int64_t level); + +// Returns True if ANY tensor in tensors is batched at level +TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level); +TORCH_API bool isBatchedAtLevel(const c10::List>& maybe_tensors, int64_t level); +TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level); +TORCH_API bool isBatchedAtLevel(const std::optional& maybe_tensor, int64_t level); + +// Convenience helper. Returns true if any tensor is batched at level +TORCH_API bool areAnyBatchedAtLevel(ArrayRef> maybe_tensors, int64_t level); + +inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) { + if (ivalue.isTensor()) { + auto maybe_level = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_level.has_value()); + auto current_level = maybe_level->layerId(); + return isBatchedAtLevel(ivalue.toTensor(), current_level); + } + // TODO: should really check this + return false; +} + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/TensorWrapper.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/TensorWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..6e1f7c8cb058be127f95e4d2cb09891b99a9f771 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/TensorWrapper.h @@ -0,0 +1,103 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace at::functorch { + +// NOTE: [functorch's TensorWrapper] +// +// Taking better suggestions for a name. TensorWrapper is the wrapper Tensor +// Subclass for functorch's grad-based transforms (grad, vjp, jvp). It is +// analogous to how vmap uses BatchedTensor as the wrapper Tensor subclass. +// +// If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively +// another Variable. +// +// Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x)). +// The reason why is so that each TensorWrapper can hold its own AutogradMeta and +// participate in a **separate** autograd graph. 
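+//
+// A minimal sketch of that nesting, using the helpers declared at the end of
+// this header (`x` is an assumed plain Tensor, and the sketch assumes the
+// corresponding grad interpreters are already on the DynamicLayerStack, per the
+// caveat further down):
+//
+//   Tensor w1 = makeTensorWrapper(x, /*level=*/1);    // TensorWrapper(x)
+//   Tensor w2 = makeTensorWrapper(w1, /*level=*/2);   // TensorWrapper(TensorWrapper(x))
+//   TensorWrapper* impl = maybeGetTensorWrapper(w2);  // non-null; impl->value() is w1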
+// +// There are alternative designs we could have chosen (e.g. each grad transform +// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper +// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels) +// without much modification. Since a TensorWrapper looks like a regular Tensor, +// the VariableType kernel can pull out the AutogradMeta struct from where it +// expects and extend the autograd graph + +struct TORCH_API TensorWrapper : public c10::TensorImpl { + explicit TensorWrapper( + c10::DispatchKeySet key_set, + Tensor value, + int64_t level, + std::shared_ptr is_alive, + bool is_immutable = false, // if true, this came from an operation that aliases an immutable tensor + bool use_value_sizes_strides = true); + + void refreshMetadata(); + + const Tensor& value() const { + return value_; + } + std::optional level() const { + if (is_alive()) { + return level_; + } + return {}; + } + bool is_immutable() const { + return is_immutable_; + } + bool is_alive() const; + + // Overrides necessary for autograd + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + + private: + const char* tensorimpl_type_name() const override; + Tensor value_; + int64_t level_; + bool is_immutable_; + + // TensorWrapper receives a boolean flag on whether or not the Grad Interpreter + // that created it is still alive or not. + // If the Grad Interpreter is no longer alive then it attempts to behave like + // a regular Tensor. + // + // When we exit the level, this wrapper may be marked as "not alive". + // Wrappers that are not alive: + // 1) May still have autograd metadata on them + // 2) Forward dispatches to the underlying value() + std::shared_ptr is_alive_; +}; + +// There are two variants of makeTensorWrapper: one that accepts a level +// and one that accepts an Interpreter. +// +// The one that accepts a level tries to automatically get the life handle from the +// interpreter on the DynamicLayerStack. +// It needs to be used with caution: if the interpreter is not on the +// DynamicLayerStack, then we won't be able to find the life handle. +// +// In practice this isn't a problem: when we're constructing TensorWrapper in +// Python, the corresponding interpreter is on the stack. +TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false); +TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false); +TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor); +TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor); +TORCH_API void dumpTensorCout(const Tensor& tensor); + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/functorch/VmapInterpreter.h b/phivenv/Lib/site-packages/torch/include/ATen/functorch/VmapInterpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..8a2539e24faeae1308dc8376bfc3a2b15d438179 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/functorch/VmapInterpreter.h @@ -0,0 +1,25 @@ +#pragma once +#include + +namespace at::functorch { + +// This is the interpreter that handles the functionalize() transform. 
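+// (Note that the transform handled in this header is vmap: the pointer type
+// below asserts TransformType::Vmap and exposes vmap-specific metadata.)
+//
+// A minimal accessor sketch (`interp` is an assumed Interpreter created via
+// Interpreter::Vmap):
+//
+//   VmapInterpreterPtr ptr(&interp);
+//   c10::SymInt bs = ptr.batchSize();       // per-level batch size
+//   RandomnessType rt = ptr.randomness();   // Error / Same / Different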
+// See NOTE: [functorch interpreter stack] for more details. + +struct VmapInterpreterPtr { + explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + c10::SymInt batchSize() const { + return std::get(base_->meta()).batchSize_; + } + RandomnessType randomness() const { + return std::get(base_->meta()).randomness_; + } + private: + const Interpreter* base_; +}; + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..fe3c35ecb1748d39cec85933d05f2adb01ef62ee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10::hip { + +// Takes a valid HIPAllocator (of any sort) and turns it into +// an allocator pretending to be a CUDA allocator. See +// Note [Masquerading as CUDA] +class HIPAllocatorMasqueradingAsCUDA final : public Allocator { + Allocator* allocator_; +public: + explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator) + : allocator_(allocator) {} + DataPtr allocate(size_t size) override { + DataPtr r = allocator_->allocate(size); + r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index())); + return r; + } + DeleterFnPtr raw_deleter() const override { + return allocator_->raw_deleter(); + } + void copy_data(void* dest, const void* src, std::size_t count) const final { + allocator_->copy_data(dest, src, count); + } +}; + +} // namespace c10::hip diff --git a/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..4811b0d5e45e984bea140496c9ae10684d16e040 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include +#include + +namespace c10 { +// forward declaration +class DataPtr; +namespace hip { +namespace HIPCachingAllocatorMasqueradingAsCUDA { + +C10_HIP_API Allocator* get(); +C10_HIP_API void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream); + +} // namespace HIPCachingAllocatorMasqueradingAsCUDA +} // namespace hip +} // namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..b6d8e0a537d4422ec786d63436f59484ba8f2ff9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -0,0 +1,383 @@ +#pragma once + +#include + +// The includes of HIPGuard.h +#include +#include +#include +#include +#include +#include + +#include + +#include 
+#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10 { namespace hip { + +// Note [Masquerading as CUDA] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// c10_hip is very easy to understand: it is HIPified from c10_cuda, +// and anywhere you said CUDA, the source code now says HIP. HIPified +// PyTorch is much harder to understand: it is HIPified from regular +// PyTorch, yes, but NO source-to-source translation from CUDA to +// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP". +// For example, when you use HIPified PyTorch, you say x.cuda() to +// move a tensor onto ROCm device. We call this situation "HIP +// masquerading as CUDA". +// +// This leads to a very awkward situation when we want to call c10_hip +// code from PyTorch, since c10_hip is expecting things to be called +// HIP, but PyTorch is calling them CUDA (masquerading as HIP). To +// fix this impedance mismatch, we have MasqueradingAsCUDA variants +// for all c10_hip classes. These translate between the "HIP" and "CUDA +// masquerading as HIP" worlds. For example, +// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a +// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type() +// returns CUDA, getDevice() reports the current HIP device as a CUDA +// device.) +// +// We should be able to delete all of these classes entirely once +// we switch PyTorch to calling a HIP a HIP. +// +// When you add a new MasqueradingAsCUDA class/function, you need to +// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py +// +// +// +// By the way, note that the cpp file associated with this also +// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with +// this HIP implementation. 
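+//
+// A concrete illustration of the masquerade (hedged; the calls below are the
+// ordinary user-facing tensor API, nothing specific to this header):
+//
+//   torch::Tensor t = torch::ones({2, 2}).cuda();      // on a ROCm build this runs HIP
+//   bool reported_as_cuda =
+//       (t.device().type() == c10::DeviceType::CUDA);  // true: HIP reports itself as CUDA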
+ +struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA; + HIPGuardImplMasqueradingAsCUDA() {} + HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA); + } + c10::DeviceType type() const override { + return c10::DeviceType::CUDA; + } + Device exchangeDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_cuda()); + Device old_device = getDevice(); + if (old_device.index() != d.index()) { + C10_HIP_CHECK(hipSetDevice(d.index())); + } + return old_device; + } + Device getDevice() const override { + int device; + C10_HIP_CHECK(hipGetDevice(&device)); + return Device(c10::DeviceType::CUDA, device); + } + void setDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_cuda()); + C10_HIP_CHECK(hipSetDevice(d.index())); + } + void uncheckedSetDevice(Device d) const noexcept override { + C10_HIP_CHECK_WARN(hipSetDevice(d.index())); + } + Stream getStream(Device d) const override { + return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap(); + } + Stream getDefaultStream(Device d) const override { + return getDefaultHIPStreamMasqueradingAsCUDA(d.index()); + } + Stream getNewStream(Device d, int priority = 0) const override { + return getStreamFromPoolMasqueradingAsCUDA(priority, d.index()); + } + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override { + return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index()); + } + Stream exchangeStream(Stream s) const override { + HIPStreamMasqueradingAsCUDA cs(s); + auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index()); + setCurrentHIPStreamMasqueradingAsCUDA(cs); + return old_stream.unwrap(); + } + DeviceIndex deviceCount() const noexcept override { + int deviceCnt; + hipError_t _err; + _err = hipGetDeviceCount(&deviceCnt); + if(_err != hipErrorNoDevice && _err != hipSuccess) + C10_HIP_CHECK(_err); + return deviceCnt; + } + + // Event-related functions + // Note: hipEventCreateWithFlags should be called on the same device as + // the recording stream's device. 
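+  // Rough lifecycle sketch for the event helpers below (hedged; in practice the
+  // caller is the generic c10 Event machinery, and the names here are
+  // illustrative):
+  //
+  //   void* ev = nullptr;
+  //   record(&ev, stream, /*device_index=*/-1, EventFlag::PYTORCH_DEFAULT);  // lazily creates the hipEvent_t
+  //   block(ev, other_stream);      // other_stream waits on the event
+  //   bool done = queryEvent(ev);   // non-blocking completion check
+  //   destroyEvent(ev, stream.device_index());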
+ void createEvent( + hipEvent_t* hip_event, + const EventFlag flag) const { + // Maps PyTorch's Event::Flag to HIP flag + auto hip_flag = hipEventDefault; + switch (flag) { + case EventFlag::PYTORCH_DEFAULT: + hip_flag = hipEventDisableTiming; + break; + case EventFlag::BACKEND_DEFAULT: + hip_flag = hipEventDefault; + break; + default: + TORCH_CHECK(false, "HIP event received unknown flag"); + } + + C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag)); + } + + void destroyEvent( + void* event, + const DeviceIndex device_index) const noexcept override { + if (!event) return; + auto hip_event = static_cast(event); + int orig_device; + C10_HIP_CHECK_WARN(hipGetDevice(&orig_device)); + C10_HIP_CHECK_WARN(hipSetDevice(device_index)); + C10_HIP_CHECK_WARN(hipEventDestroy(hip_event)); + C10_HIP_CHECK_WARN(hipSetDevice(orig_device)); + } + + void record(void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); + + hipEvent_t hip_event = static_cast(*event); + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + + // Moves to stream's device to record + const auto orig_device = getDevice(); + setDevice(stream.device()); + + // Creates the event (lazily) + if (!hip_event) createEvent(&hip_event, flag); + C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream)); + // Makes the void* point to the (possibly just allocated) HIP event + *event = hip_event; + + // Resets device + setDevice(orig_device); + } + + void block( + void* event, + const Stream& stream) const override { + if (!event) return; + hipEvent_t hip_event = static_cast(event); + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + const auto orig_device = getDevice(); + setDevice(stream.device()); + C10_HIP_CHECK(hipStreamWaitEvent( + hip_stream, + hip_event, + /*flags (must be zero)=*/ 0)); + setDevice(orig_device); + } + + bool queryEvent(void* event) const override { + if (!event) return true; + hipEvent_t hip_event = static_cast(event); + const hipError_t err = hipEventQuery(hip_event); + if (err != hipErrorNotReady) C10_HIP_CHECK(err); + else { + // ignore and clear the error if not ready + (void)hipGetLastError(); + } + return (err == hipSuccess); + } + + // Stream-related functions + bool queryStream(const Stream& stream) const override { + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + return hip_stream.query(); + } + + void synchronizeStream(const Stream& stream) const override { + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + hip_stream.synchronize(); + } + + void synchronizeEvent(void* event) const override { + if (!event) + return; + hipEvent_t hip_event = static_cast(event); + C10_HIP_CHECK(hipEventSynchronize(hip_event)); + } + + // Note: synchronizeDevice can be safely called from any device + void synchronizeDevice(const c10::DeviceIndex device_index) const override { + int orig_device{-1}; + C10_HIP_CHECK(hipGetDevice(&orig_device)); + C10_HIP_CHECK(hipSetDevice(device_index)); + C10_HIP_CHECK(hipDeviceSynchronize()); + C10_HIP_CHECK(hipSetDevice(orig_device)); + } + + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const Stream& stream) const override { + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream); + } + + double elapsedTime(void* event1, void* event2, 
const DeviceIndex device_index) + const override { + TORCH_CHECK( + event1 && event2, + "Both events must be recorded before calculating elapsed time."); + int orig_device; + C10_HIP_CHECK(hipGetDevice(&orig_device)); + C10_HIP_CHECK(hipSetDevice(device_index)); + hipEvent_t hip_event1 = static_cast(event1); + hipEvent_t hip_event2 = static_cast(event2); + float time_ms = 0; + // raise hipErrorNotReady if either event is recorded but not yet completed + C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2)); + C10_HIP_CHECK(hipSetDevice(orig_device)); + return static_cast(time_ms); + } +}; + +// All of the guards which have HIPGuardImpl burned in need to also have +// variants using HIPGuardImplMasqueradingAsCUDA. + +/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with +/// the correct InlineDeviceGuard burned in. Sorry about the +/// copy-pasting. + +struct HIPGuardMasqueradingAsCUDA { + explicit HIPGuardMasqueradingAsCUDA() = delete; + explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {} + explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {} + + HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete; + HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete; + HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete; + HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete; + + void set_device(Device device) { guard_.set_device(device); } + void reset_device(Device device) { guard_.reset_device(device); } + void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + Device original_device() const { return guard_.original_device(); } + Device current_device() const { return guard_.current_device(); } + + private: + c10::impl::InlineDeviceGuard guard_; +}; + +struct OptionalHIPGuardMasqueradingAsCUDA { + explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {} + explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional device_opt) : guard_(device_opt) {} + explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional device_index_opt) : guard_(device_index_opt) {} + + OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; + OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; + OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; + OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; + + void set_device(Device device) { guard_.set_device(device); } + void reset_device(Device device) { guard_.reset_device(device); } + void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + std::optional original_device() const { return guard_.original_device(); } + std::optional current_device() const { return guard_.current_device(); } + void reset() { guard_.reset(); } + +private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + +struct HIPStreamGuardMasqueradingAsCUDA { + explicit HIPStreamGuardMasqueradingAsCUDA() = delete; + explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} + HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete; + HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete; + HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; + HIPStreamGuardMasqueradingAsCUDA& 
operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; + + void reset_stream(Stream stream) { guard_.reset_stream(stream); } + + HIPStreamMasqueradingAsCUDA original_stream() const { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream()); + } + HIPStreamMasqueradingAsCUDA current_stream() const { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream()); + } + + Device current_device() const { return guard_.current_device(); } + Device original_device() const { return guard_.original_device(); } + +private: + c10::impl::InlineStreamGuard guard_; +}; + +struct OptionalHIPStreamGuardMasqueradingAsCUDA { + explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {} + explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} + explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional stream_opt) : guard_(stream_opt) {} + + OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; + + void reset_stream(Stream stream) { guard_.reset_stream(stream); } + + std::optional original_stream() const { + auto r = guard_.original_stream(); + if (r.has_value()) { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()); + } else { + return std::nullopt; + } + } + + std::optional current_stream() const { + auto r = guard_.current_stream(); + if (r.has_value()) { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()); + } else { + return std::nullopt; + } + } + + void reset() { guard_.reset(); } + +private: + c10::impl::InlineOptionalStreamGuard guard_; +}; + +struct HIPMultiStreamGuardMasqueradingAsCUDA { + explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef streams) + : guard_(unwrapStreams(streams)) {} + + HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete; + HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete; + HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete; + HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete; + +private: + c10::impl::InlineMultiStreamGuard guard_; + + static std::vector unwrapStreams(ArrayRef hipStreams) { + std::vector streams; + streams.reserve(hipStreams.size()); + for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) { + streams.push_back(hipStream); + } + return streams; + } +}; + +}} // namespace c10::hip diff --git a/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..fc39ec1681e538f3d01450e9611896ad04d41dab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h @@ -0,0 +1,135 @@ +#pragma once + +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! 
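+//
+// A small usage sketch (hedged; `device_index` is an assumed in-scope
+// DeviceIndex, and the helpers used are the ones declared further down):
+//
+//   auto hip_s = getCurrentHIPStreamMasqueradingAsCUDA(device_index);
+//   c10::Stream generic = hip_s.unwrap();       // reported as a "CUDA" stream
+//   hipStream_t raw = hip_s.stream();           // underlying HIP stream handle
+//   HIPStreamMasqueradingAsCUDA back(generic);  // checked round-trip coercion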
+namespace c10 { namespace hip { + +// See Note [Masquerading as CUDA] for motivation + +class HIPStreamMasqueradingAsCUDA { +public: + + enum Unchecked { UNCHECKED }; + + explicit HIPStreamMasqueradingAsCUDA(Stream stream) + : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) { + // We did the coercion unchecked; check that it was right. + TORCH_CHECK(stream.device().is_cuda() /* !!! */); + } + + explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream) + // Unsafely coerce the "CUDA" stream into a HIP stream + : stream_( + HIPStream( + Stream( + Stream::UNSAFE, + Device(c10::DeviceType::HIP, stream.device_index()), + stream.id()) + ) + ) {} + + // New constructor, just for this. Does NOT coerce. + explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {} + + bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept { + return stream_ == other.stream_; + } + + bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept { + return stream_ != other.stream_; + } + + operator hipStream_t() const { return stream_.stream(); } + + operator Stream() const { + // Unsafely coerce HIP stream into a "CUDA" stream + return Stream(Stream::UNSAFE, device(), id()); + } + + DeviceIndex device_index() const { return stream_.device_index(); } + + // Unsafely coerce HIP device into CUDA device + c10::DeviceType device_type() const { return c10::DeviceType::CUDA; } + + Device device() const { + // Unsafely coerce HIP device into CUDA device + return Device(c10::DeviceType::CUDA, stream_.device_index()); + } + + StreamId id() const { return stream_.id(); } + bool query() const { return stream_.query(); } + void synchronize() const { stream_.synchronize(); } + int priority() const { return stream_.priority(); } + hipStream_t stream() const { return stream_.stream(); } + + Stream unwrap() const { + // Unsafely coerce HIP stream into "CUDA" stream + return Stream(Stream::UNSAFE, device(), id()); + } + + c10::StreamData3 pack3() const noexcept { + // Unsafely coerce HIP stream into "CUDA" stream before packing + return unwrap().pack3(); + } + + static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id, + DeviceIndex device_index, + c10::DeviceType device_type) { + // NB: constructor manages CUDA->HIP translation for us + return HIPStreamMasqueradingAsCUDA(Stream::unpack3( + stream_id, device_index, device_type)); + } + + static std::tuple priority_range() { return HIPStream::priority_range(); } + + // New method, gets the underlying HIPStream + HIPStream hip_stream() const { return stream_; } + +private: + HIPStream stream_; +}; + +HIPStreamMasqueradingAsCUDA +inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) { + return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device)); +} + +HIPStreamMasqueradingAsCUDA +inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) { + return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device)); +} + +HIPStreamMasqueradingAsCUDA +inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) { + return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device)); +} + +inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) { + return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index)); +} + +inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) { + return 
HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index)); +} + +inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) { + setCurrentHIPStream(stream.hip_stream()); +} + +inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) { + stream << s.hip_stream() << " (masquerading as CUDA)"; + return stream; +} + +}} // namespace c10::hip + +namespace std { + template <> + struct hash { + size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept { + return std::hash{}(s.unwrap()); + } + }; +} // namespace std diff --git a/phivenv/Lib/site-packages/torch/include/ATen/metal/Context.h b/phivenv/Lib/site-packages/torch/include/ATen/metal/Context.h new file mode 100644 index 0000000000000000000000000000000000000000..cf199668eb061f0b7d7a4d7099c5ba813162ba66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/metal/Context.h @@ -0,0 +1,32 @@ +#ifndef MetalContext_h +#define MetalContext_h + +#include + +#include + +namespace at::metal { + +struct MetalInterface { + virtual ~MetalInterface() = default; + virtual bool is_metal_available() const = 0; + virtual at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) + const = 0; +}; + +extern std::atomic g_metal_impl_registry; + +class MetalImplRegistrar { + public: + explicit MetalImplRegistrar(MetalInterface*); +}; + +at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src); + +} // namespace at::metal + +namespace at::native { +bool is_metal_available(); +} // namespace at::native + +#endif /* MetalContext_h */ diff --git a/phivenv/Lib/site-packages/torch/include/ATen/miopen/Descriptors.h b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Descriptors.h new file mode 100644 index 0000000000000000000000000000000000000000..74d3253557602e63ee179d6bf7b2d731d782db96 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Descriptors.h @@ -0,0 +1,169 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace at { namespace native { + +inline int dataSize(miopenDataType_t dataType) +{ + switch (dataType) { + case miopenHalf: return 2; + case miopenFloat: return 4; + case miopenBFloat16: return 2; + default: return 8; + } +} + +template +struct DescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + MIOPEN_CHECK(dtor(x)); + } + } +}; + +// A generic class for wrapping MIOpen descriptor types. All you need +// is to give the underlying type the Descriptor_t points to (usually, +// if it's miopenTensorDescriptor_t it points to miopenTensorStruct), +// the constructor and the destructor. Subclasses are responsible +// for defining a set() function to actually set the descriptor. +// +// Descriptors default construct to a nullptr, and have a descriptor +// initialized the first time you call set() or any other initializing +// function. +template +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_CUDA_CPP_API Descriptor { + public: + // Use desc() to access the underlying descriptor pointer in + // a read-only fashion. Most client code should use this. + // If the descriptor was never initialized, this will return + // nullptr. + T* desc() const { return desc_.get(); } + T* desc() { return desc_.get(); } + + // Use mut_desc() to access the underlying descriptor pointer + // if you intend to modify what it points to (e.g., using + // miopenSetFooDescriptor). This will ensure that the descriptor + // is initialized. Code in this file will use this function. 
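+  // Illustrative call pattern (mirrors the subclass setters further down in
+  // this header, e.g. TensorDescriptor::set):
+  //
+  //   MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  //
+  // i.e. writers go through mut_desc() so the descriptor is created on first
+  // use, while read-only consumers hand desc() straight to MIOpen calls.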
+ T* mut_desc() { init(); return desc_.get(); } +protected: + void init() { + if (desc_ == nullptr) { + T* raw_desc = nullptr; + MIOPEN_CHECK(ctor(&raw_desc)); + desc_.reset(raw_desc); + } + } +private: + std::unique_ptr> desc_; +}; + +class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + TensorDescriptor() = default; + explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) { + set(t, pad); + } + + void set(const at::Tensor &t, size_t pad = 0); + void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0); + + void print(); + +private: + void set(miopenDataType_t dataType, int dim, int* size, int* stride) { + MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride)); + } +}; + +std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); + +class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + void set(const at::Tensor &t, int64_t pad = 0) { + set(t, at::MemoryFormat::Contiguous, pad); + } + + void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0); + +private: + void set(miopenDataType_t dataType, int dim, int* size, int* stride) { + MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride)); + } +}; + +struct TORCH_CUDA_CPP_API ConvolutionDescriptor + : public Descriptor< + miopenConvolutionDescriptor, + &miopenCreateConvolutionDescriptor, + &miopenDestroyConvolutionDescriptor> { + void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) { + MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode)); + MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups)); + MIOPEN_CHECK(miopenSetConvolutionAttribute(mut_desc(), MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, deterministic ? 
1 : 0)); + if (benchmark) { + MIOPEN_CHECK(miopenSetConvolutionFindMode(mut_desc(), miopenConvolutionFindModeNormal)); + } + } +}; + +// NOLINTNEXTLINE(bugprone-exception-escape) +struct TORCH_CUDA_CPP_API DropoutDescriptor + : public Descriptor< + miopenDropoutDescriptor, + &miopenCreateDropoutDescriptor, + &miopenDestroyDropoutDescriptor> { + void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } + + void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } +}; + +struct TORCH_CUDA_CPP_API RNNDescriptor + : public Descriptor +{ + void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode, + miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) { + MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype)); + } + + void setWithDropout(DropoutDescriptor& dropout_desc, int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, + miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) { + MIOPEN_CHECK(miopenSetRNNDescriptor_V2(mut_desc(), hidden_size, num_layers, dropout_desc.mut_desc(), input_mode, direction, rnn_mode, bias_mode, algorithm, datatype)); + } +}; + +union Constant +{ + float f; + double d; + Constant(miopenDataType_t dataType, double value) { + if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) { + f = static_cast(value); + } else { + d = value; + } + } +}; + +}} // namespace diff --git a/phivenv/Lib/site-packages/torch/include/ATen/miopen/Exceptions.h b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..044ae3222aa83e512c796fc2b903b2a111285015 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Exceptions.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { namespace native { + +class miopen_exception : public std::runtime_error { +public: + miopenStatus_t status; + miopen_exception(miopenStatus_t status, const char* msg) + : std::runtime_error(msg) + , status(status) {} + miopen_exception(miopenStatus_t status, const std::string& msg) + : std::runtime_error(msg) + , status(status) {} +}; + +inline void MIOPEN_CHECK(miopenStatus_t status) +{ + if (status != miopenStatusSuccess) { + if (status == miopenStatusNotImplemented) { + throw miopen_exception(status, std::string(miopenGetErrorString(status)) + + ". 
This error may appear if you passed in a non-contiguous input."); + } + throw miopen_exception(status, miopenGetErrorString(status)); + } +} + +inline void HIP_CHECK(hipError_t error) +{ + if (error != hipSuccess) { + std::string msg("HIP error: "); + msg += hipGetErrorString(error); + throw std::runtime_error(msg); + } +} + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/miopen/Handle.h b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Handle.h new file mode 100644 index 0000000000000000000000000000000000000000..6ee016afbb6a1ca9c6201633ebd8368e2ceceb6c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Handle.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle(); +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/miopen/Types.h b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Types.h new file mode 100644 index 0000000000000000000000000000000000000000..31e33cda39e64d8cb07c38474434e0f543d46837 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Types.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor); + +int64_t miopen_version(); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/miopen/Utils.h b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..30f8e228165664c6e358838df3c26d4074ccd173 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/miopen/Utils.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include +#include + +namespace at { namespace native { + +// This function makes tensors which have zero stride contiguous, by +// setting the strides to 1. +inline Tensor contiguousIfZeroInStrides(const Tensor& t) { + for (auto s : t.strides()) { + if (s == 0) return t.contiguous(); + } + return t; +} + +}} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/miopen/miopen-wrapper.h b/phivenv/Lib/site-packages/torch/include/ATen/miopen/miopen-wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..715fd6ed269031b942bfb5ebe9c6a29f0f768216 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/miopen/miopen-wrapper.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +#if MIOPEN_VERSION_MAJOR > 3 || (MIOPEN_VERSION_MAJOR == 3 && MIOPEN_VERSION_MINOR >= 4) +// miopen 3.4 moved find mode from private header to public header +#else +// from miopen_internal.h +extern "C" { + +typedef enum +{ + miopenConvolutionFindModeNormal = 1, /*!< Normal mode */ +} miopenConvolutionFindMode_t; + +miopenStatus_t miopenSetConvolutionFindMode( + miopenConvolutionDescriptor_t convDesc, + miopenConvolutionFindMode_t findMode); +} +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/EmptyTensor.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..2d0bcc09013132f3d3eb262d1e77bbdcc1ceb85f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/EmptyTensor.h @@ -0,0 +1,28 @@ +// Copyright © 2022 Apple Inc. 
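A hedged usage sketch for the MIOpen wrapper classes above (editor's example, not part of the vendored header): the descriptor wrappers call MIOPEN_CHECK internally, so a caller only fills them from at::Tensor objects and plain int arrays. The helper name, the include path, and the assumption that pad is the minimum descriptor dimensionality are illustrative.

#include <ATen/miopen/Descriptors.h>

static void describeConv2d(const at::Tensor& input, const at::Tensor& weight) {
  // Tensor/filter descriptors built straight from tensors; pad=4 is assumed to
  // expand the descriptors to at least 4 dims (N, C, H, W).
  at::native::TensorDescriptor idesc(input, /*pad=*/4);
  at::native::FilterDescriptor wdesc;
  wdesc.set(weight, /*pad=*/4);

  // Convolution descriptor: 2-D convolution, unit stride/dilation, one group.
  int padding[2] = {1, 1};
  int stride[2] = {1, 1};
  int dilation[2] = {1, 1};
  at::native::ConvolutionDescriptor cdesc;
  cdesc.set(miopenFloat, miopenConvolution, /*dim=*/2, padding, stride, dilation,
            /*groups=*/1, /*benchmark=*/false, /*deterministic=*/true);
}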
+ +#pragma once +#include + +namespace at::detail { + +C10_EXPORT TensorBase empty_mps( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); +C10_EXPORT TensorBase empty_mps(IntArrayRef size, const TensorOptions& options); + +C10_EXPORT TensorBase empty_strided_mps( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + std::optional device_opt); + +C10_EXPORT TensorBase empty_strided_mps( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions& options); + +} // namespace at::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/IndexKernels.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/IndexKernels.h new file mode 100644 index 0000000000000000000000000000000000000000..9399c3fbfafe75d5f7ea5b4681c1def9919f98b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/IndexKernels.h @@ -0,0 +1,220 @@ +#pragma once + +namespace at::mps { + +static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER( +template +Y cast(const X x); + +template<> +{1} cast<{1}, {0}>(const {0} x) {{ + return {2}; +}} + +kernel void scatter_kernel_n(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant uint32_t * size [[buffer(2)]], + constant uint32_t * stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]], + constant int32_t & ndim [[buffer(5)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + uint64_t dst_offs = 0; + auto dst_idx = linear_index; + for(int dim = ndim - 1; dim >= 0; --dim) {{ + dst_offs += stride[dim] * (dst_idx % size[dim]); + dst_idx /= size[dim]; + }} + + dst[dst_offs] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint4 & size [[buffer(2)]], + constant packed_uint4 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint4 local_index; + local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0]; + local_index.y = linear_index / (size[3] * size[2]) % size[1]; + local_index.z = linear_index / size[3] % size[2]; + local_index.w = linear_index % size[3]; + + const packed_uint4 strided_index = local_index * stride; + dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint3 & size [[buffer(2)]], + constant packed_uint3 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint3 local_index; + local_index.x = linear_index / (size[2] * size[1]) % size[0]; + local_index.y = linear_index / size[2] % size[1]; + local_index.z = linear_index % size[2]; + + const packed_uint3 strided_index = local_index * stride; + dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_2(uint linear_index 
[[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint2 & size [[buffer(2)]], + constant packed_uint2 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint2 local_index; + local_index.x = linear_index / size[1] % size[0]; + local_index.y = linear_index % size[1]; + + const packed_uint2 strided_index = local_index * stride; + dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant int & size [[buffer(2)]], + constant int & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + const int local_index = linear_index % size; + const int strided_index = local_index * stride; + dst[strided_index] = cast<{1}>(src[linear_index]); +}} +)METAL_SCATTER"; + +static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER( +template +Y cast(const X x); + +template<> +{1} cast<{1}, {0}>(const {0} x) {{ + return {2}; +}} + +kernel void gather_kernel_n(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant uint32_t * size [[buffer(2)]], + constant uint32_t * stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]], + constant int32_t & ndim [[buffer(5)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + uint64_t src_offs = 0; + auto src_idx = linear_index; + for(int dim = ndim - 1; dim >= 0; --dim) {{ + src_offs += stride[dim] * (src_idx % size[dim]); + src_idx /= size[dim]; + }} + + dst[linear_index] = cast<{1}>(src[src_offs]); +}} + +kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint4 & size [[buffer(2)]], + constant packed_uint4 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint4 local_index; + local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0]; + local_index.y = linear_index / (size[3] * size[2]) % size[1]; + local_index.z = linear_index / size[3] % size[2]; + local_index.w = linear_index % size[3]; + + const packed_uint4 strided_index = local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]); +}} + +kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint3 & size [[buffer(2)]], + constant packed_uint3 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint3 local_index; + local_index.x = linear_index / (size[2] * size[1]) % size[0]; + local_index.y = linear_index / size[2] % size[1]; + local_index.z = linear_index % size[2]; + + const packed_uint3 strided_index = 
local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]); +}} + +kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint2 & size [[buffer(2)]], + constant packed_uint2 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint2 local_index; + local_index.x = linear_index / size[1] % size[0]; + local_index.y = linear_index % size[1]; + + const packed_uint2 strided_index = local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]); +}} + +kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant int & size [[buffer(2)]], + constant int & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + const int local_index = linear_index % size; + const int strided_index = local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index]); +}} +)METAL_GATHER"; +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSAllocator.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..9ea2edb3b88d177c876979d0e1c354a7fb0f686f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSAllocator.h @@ -0,0 +1,437 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// this implementation is based on CUDACachingAllocator. +// It utilizes Metal Heaps to improve the performance with buffer allocation. +// Do not include this header. Use MPSAllocatorInterface.h instead. +// TODO: Unify the logic with CUDACachingAllocator and remove redundant code. 
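The comment above directs callers to MPSAllocatorInterface.h rather than this header; a hedged sketch of such a caller follows (editor's example), using only methods declared on IMPSAllocator later in this patch. The helper name logMPSMemoryUsage is illustrative.

#include <ATen/mps/MPSAllocatorInterface.h>
#include <iostream>

static void logMPSMemoryUsage() {
  at::mps::IMPSAllocator* alloc = at::mps::getIMPSAllocator();
  if (alloc == nullptr) {
    return;  // MPS allocator not available in this build/runtime
  }
  // Memory held by the caching allocator, including blocks parked in pools.
  std::cout << "allocated: " << alloc->formatSize(alloc->getTotalAllocatedMemory()) << '\n';
  // Memory actively backing live tensors (blocks not sitting in pools).
  std::cout << "in use:    " << alloc->formatSize(alloc->getCurrentAllocatedMemory()) << '\n';
  // Process-wide GPU allocation as reported by the Metal driver.
  std::cout << "driver:    " << alloc->formatSize(alloc->getDriverAllocatedMemory()) << '\n';
}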
+namespace at::mps::HeapAllocator { + +static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB +static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap +static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB +static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps +static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps +static const size_t kXLargeHeapD = + MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps +static const size_t kXLargeHeapU = + MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps +static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation + +// buffer pools could be customized with a combination of usage flags +enum UsageFlags : uint32_t { + PRIVATE = 0, + SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap + SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device + MANAGED = (1 << 2), // managed storage mode + HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool + SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream +}; +// debug verbosity flags +enum DebugVerbosity : uint32_t { + SILENT = 0, + PROFILING = (1 << 0), // print generic profiling data for total system memory usage + ALLOCATIONS = (1 << 1), // print buffer allocations + RECYCLES = (1 << 2), // print buffer recycling + RELEASES = (1 << 3), // print buffer releases + LARGE_ONLY = (1 << 4), // only log large buffer pool transactions +}; + +struct HeapBlock; + +struct BufferBlock { + id buffer; + void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer + size_t size; // size after alignment + size_t requested_size; // requested size (before alignment) + // buffer shape is used for retrieving base of views in cached graphs + std::vector shape; + bool in_use = false; + HeapBlock* heap; + id_t buf_id; + // counter to candidate least recently used buffers for garbage collection + uint32_t gc_count = 0; + uint32_t use_count = 0; + // counter to assign unique ids to buffer blocks + static uint64_t buffer_counter; + // Metal events used to sync GPU/CPU operations on the shared-storage buffers + MPSEventPtr event; + + BufferBlock(size_t Size, size_t RequestedSize = 0, const id Buffer = nullptr, HeapBlock* Heap = nullptr) + : buffer(Buffer), size(Size), requested_size(RequestedSize), heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) {} + + static bool Comparator(const BufferBlock* a, const BufferBlock* b) { + return (a->size != b->size) ? 
a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer; + } + static size_t alignUp(size_t Size, size_t Alignment) { + assert(((Alignment - 1) & Alignment) == 0); + return ((Size + Alignment - 1) & ~(Alignment - 1)); + } + uint32_t retainCount() const { + return [buffer retainCount]; + } +}; +typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*); + +struct BufferPool; +struct AllocParams { + AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) + : search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) {} + size_t size() const { + return search_key.size; + } + + BufferBlock search_key; + BufferPool* pool; + BufferBlock* buffer_block = nullptr; + size_t requested_size; + // true if we exceed the low watermark limit. In this case + // we apply strategies to relieve the pressure before allocation. + bool has_memory_pressure = false; + // true if we're allocating on a unified memory device + bool has_unified_memory = true; +}; + +struct HeapBlock { + id heap; + struct { + size_t total, available; + } size; + BufferPool* pool; + unsigned int n_buffers = 0; + id_t heap_id; + // indicates if we split this heap to sub-allocate 'several' buffers (otherwise single buffer) + bool is_split; + // counter to assign unique ids to heap blocks + static uint64_t heap_counter; + + HeapBlock(size_t Size, const id Heap = nullptr, BufferPool* Pool = nullptr) + : heap(Heap), + size({.total = Size, .available = Size}), + pool(Pool), + heap_id(Heap ? ++heap_counter : 0), + is_split(true) {} + + static MTLResourceOptions getOptions(uint32_t usage) { + // TODO: check the caching performance of write-combined mode + MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache; + + if (usage & UsageFlags::MANAGED) + options |= MTLResourceStorageModeManaged; + else if (usage & UsageFlags::SHARED) + options |= MTLResourceStorageModeShared; + else + options |= MTLResourceStorageModePrivate; + + options |= + (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked; + + return options; + } + + static HeapBlock* createHeapBlock(AllocParams& params, id device, uint32_t usage) { + HeapBlock* heapBlock = nullptr; + bool is_split = true; + const size_t size = params.size(); + MTLHeapDescriptor* d = [MTLHeapDescriptor new]; + if (d) { + const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD; + if (size <= kMaxSmallAlloc) { + d.size = kSmallHeap; + } else if (size < kMinLargeAlloc) { + d.size = kLargeHeap; + } else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) { + d.size = kXLargeHeap; + } else { + d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); + is_split = false; + } + d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate; + d.cpuCacheMode = MTLCPUCacheModeDefaultCache; + // this automatically handles Metal buffer access synchronizations at the + // cost of slightly lower performance. + d.hazardTrackingMode = + (usage & UsageFlags::HAZARD) ? 
MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked; + d.resourceOptions = getOptions(usage); + d.type = MTLHeapTypeAutomatic; + id heap = [device newHeapWithDescriptor:d]; + if (heap) { + [heap setPurgeableState:MTLPurgeableStateNonVolatile]; + const size_t heap_size = heapAvailableSize(heap); + heapBlock = new HeapBlock(heap_size, heap, params.pool); + if (heapBlock) { + heapBlock->is_split = is_split; + } + } + [d release]; + } + return heapBlock; + } + static bool Comparator(const HeapBlock* a, const HeapBlock* b) { + return (a->size.available != b->size.available) ? a->size.available < b->size.available + : (uintptr_t)a->heap < (uintptr_t)b->heap; + } + static NSUInteger heapAvailableSize(id heap, size_t Alignment = vm_page_size) { + return [heap maxAvailableSizeWithAlignment:Alignment]; + } + NSUInteger Size() { + return [heap size]; + } + id newMTLBuffer(size_t length, uint32_t usage) { + id buf = [heap newBufferWithLength:length options:getOptions(usage)]; + if (buf) { + updateAvailableSize(); + n_buffers++; + } + return buf; + } + // returns the retainCount before releasing the buffer + uint32_t releaseMTLBuffer(id& buffer) { + const uint32_t retainCount = [buffer retainCount]; + [buffer release]; + buffer = nil; + updateAvailableSize(); + n_buffers--; + return retainCount; + } + // returns the retainCount before releasing the heap + uint32_t releaseMTLHeap() { + const uint32_t retainCount = [heap retainCount]; + TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty + [heap setPurgeableState:MTLPurgeableStateEmpty]; + [heap release]; + heap = nil; + size.available = 0; + return retainCount; + } + uint32_t retainCount() const { + return [heap retainCount]; + } + void updateAvailableSize() { + size.available = heapAvailableSize(heap); + } +}; +typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*); + +struct BufferPool { + enum class Kind { + PRIVATE_SMALL, + PRIVATE_LARGE, + SHARED_SMALL, + SHARED_LARGE, + SCALAR, + }; + + BufferPool(const id Device, uint32_t Usage) + : device(Device), usage(Usage), heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) {} + + const id device; + // usage flags to customize the pool for various purposes (see UsageFlags enum) + const uint32_t usage; + // total number of buffers in the pool + uint32_t n_buffers = 0; + // total allocations size on this pool + size_t allocated_size = 0; + // total memory available in the pool + size_t available_size = 0; + // list of heaps ordered by their "available" (not total) memory size + std::set heaps; + // list of only "available" buffers in the pool (i.e., buffers not in-use) + std::set available_buffers; + // list of buffers that are in a state of "limbo" where they've already been freed + // from PyTorch-side, but were not returned to pool due to still being + // in-use by command buffers with retainCount > 1. In this state, the buffer is + // neither ready to be recycled, nor could be returned to pool as available. + // These buffers will be returned to pool once the command buffer's + // completionHandler callbacks are called. 
+ std::unordered_set buffers_pending_free; + // list of heaps pending size update + std::unordered_set heaps_pending_update; +}; + +class MPSHeapAllocatorImpl { + public: + explicit MPSHeapAllocatorImpl() + : m_device(at::mps::MPSDevice::getInstance()->device()), + m_max_buffer_size([m_device maxBufferLength]), + m_stream(getDefaultMPSStream()), + m_event_pool(getMPSEventPool()) { + init_allocator(); + } + ~MPSHeapAllocatorImpl() { + emptyCache(); + } + // interface exposed to at::Allocator + id malloc(size_t size, uint32_t usage); + // frees a buffer and returns it into buffer pool + void free(void* ptr); + // releases all the cached buffers and their associated heaps + void emptyCache(); + // free inactive buffers that are pending to be freed + void freeInactiveBuffers(); + // returns true if buffer was allocated from the shared pool + bool isSharedBuffer(const void* ptr); + // get the requested unaligned size of an MTLBuffer + ssize_t getUnalignedBufferSize(const void* ptr); + // set the shape of a base tensor from a view tensor + void setBufferShape(const void* ptr, const IntArrayRef& shape); + // retrieve the shape of a base tensor from a view tensor + IntArrayRef getBufferShape(const void* ptr); + // get the unique ID of the buffer + id_t getBufferId(const void* ptr); + // allocate a buffer from a specialized pool to import CPU scalars into GPU + id allocScalarBufferWithValue(void* value, size_t size); + // returns a CPU-mapping of the input buffer and its retainCount, + // if only it has Shared storage-mode and allocated on MPSAllocator + std::pair getSharedBufferPtr(const void* buffer); + // records events for a list of MTLBuffers (list is used to lock the mutex once) + // returns true if records any event (given if passed buffers exist and are shared-storage) + bool recordEvents(c10::ArrayRef buffers); + // waits for the event to signal the completion of GPU execution + // on the passed shared buffers (list is used to lock the mutex once) + // returns true if actually waited on any event + bool waitForEvents(c10::ArrayRef buffers); + // this indicates how far (in Megabytes) the current total allocations are from the + // low watermark limit which is used to detect if we're under memory pressure + // This returns zero if we've reached the low watermark limit + ssize_t getLowWatermarkValue(); + // (see m_low_watermark_ratio for description) + void setLowWatermarkRatio(double ratio); + // (see m_high_watermark_ratio for description) + void setHighWatermarkRatio(double ratio); + // (see m_low_watermark_limit for description) + size_t getLowWatermarkLimit() const { + return m_low_watermark_limit; + } + // (see m_max_total_allowed_size for description) + size_t getHighWatermarkLimit() const { + return m_max_total_allowed_size; + } + // (see m_total_allocated_memory for description) + size_t getTotalAllocatedMemory() const { + return m_total_allocated_memory; + } + // (see m_current_allocated_memory for description) + size_t getCurrentAllocatedMemory() const { + return m_current_allocated_memory; + } + // total GPU memory allocated in the process by Metal driver; including + // implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl. 
+ size_t getDriverAllocatedMemory() const { + return current_allocated_size(); + } + // recommended Max memory for Metal + size_t getRecommendedMaxMemory() const { + return max_device_size(); + } + // (see enum DebugVerbosity for description) + uint32_t getDebugVerbosity() const { + return m_debug_verbosity; + } + // returns the device that we allocate from + inline id Device() const { + return m_device; + } + + inline std::string format_size(uint64_t size) const; + + private: + // (see m_high_watermark_ratio for description) + constexpr static double default_high_watermark_ratio = 1.7; + // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize. + constexpr static double default_high_watermark_upper_bound = 2.0; + // (see m_low_watermark_ratio for description) + // on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize + constexpr static double default_low_watermark_ratio_unified = 1.4; + constexpr static double default_low_watermark_ratio_discrete = 1.0; + + const id m_device; + std::recursive_mutex m_mutex; + // allocated buffers by device pointer + ska::flat_hash_map m_allocated_buffers; + // using a container for pools to simplify iterating them + ska::flat_hash_map> m_pools; + // total memory allocated by HeapAllocator (including blocks in pools) + size_t m_total_allocated_memory = 0; + // currently active memory allocations in use (i.e., blocks not in pools) + size_t m_current_allocated_memory = 0; + // max buffer size allowed by Metal + size_t m_max_buffer_size = 0; + // maximum total size allowed to be allocated + size_t m_max_total_allowed_size = 0; + // high watermark ratio is a hard limit for the total allowed allocations + // 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs) + // 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize) + // >1.: allows limits beyond the device.recommendedMaxWorkingSetSize + // e.g., value 0.95 means we allocate up to 95% of recommended maximum + // allocation size; beyond that, the allocations would fail with OOM error. + double m_high_watermark_ratio; + // low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark + // level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit). + // Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection) + // e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum + // allocation size. 
+ double m_low_watermark_ratio; + // low watermark size limit (in Bytes) at the time we initialize the allocator + size_t m_low_watermark_limit; + // use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity + uint32_t m_debug_verbosity; + // default MPS stream + MPSStream* m_stream; + // we hold a reference to MPSEventPool so it could get destroyed after MPSAllocator + std::shared_ptr m_event_pool; + + void init_allocator(); + void init_buffer_pools(); + HeapBlock* get_free_heap(AllocParams& params); + bool get_free_buffer(AllocParams& params); + BufferBlock* get_allocated_buffer_block(const void* ptr); + BufferBlock* alloc_buffer_block(size_t size, uint32_t usage); + bool alloc_buffer(AllocParams& params); + void free_buffer(BufferBlock* buffer_block); + // returns true if the container heap is also released + bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true); + void release_buffers(BufferPool& pool); + bool release_available_cached_buffers(AllocParams& params); + bool release_cached_buffers(); + // free unused cached blocks to reclaim GPU memory if memory pressure is high + void garbage_collect_cached_buffers(AllocParams& params); + // returns the suitable buffer pool type for the usage or + // requested/allocated sizes + BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage); + // returns the aligned allocation size that is optimized + // for the buffers to get reused frequently + size_t get_allocation_size(size_t size, uint32_t usage) const; + // maximum size of device memory available for allocation in current process + // Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory. + size_t max_device_size() const { + return [m_device recommendedMaxWorkingSetSize]; + } + // there are implicit allocations from MPS backend, so we need to query the 'device' for + // total allocated size instead of manually tracking in MPSAllocator + size_t current_allocated_size() const { + return [m_device currentAllocatedSize]; + } + + bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const { + for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) { + MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback( + buffer_block ? buffer_block->buffer : nullptr, event); + } + return true; + } +}; + +} // namespace at::mps::HeapAllocator diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..0bbec055f4bed49ea37f80858645301149a91f18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h @@ -0,0 +1,68 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include + +#define MB(x) (x * 1048576UL) + +namespace at::mps { + +// this is a public interface to access MPSAllocator. +// Do not declare methods that would depend on MPS or Metal frameworks. +class IMPSAllocator : public c10::Allocator { + public: + // see the comments in MPSAllocator.h for the description of these methods. 
+ virtual void emptyCache() const = 0; + virtual void freeInactiveBuffers() const = 0; + virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0; + virtual IntArrayRef getBufferShape(const void* ptr) const = 0; + virtual id_t getBufferId(const void* ptr) const = 0; + virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) + const = 0; + virtual bool isSharedBuffer(const void* ptr) const = 0; + virtual bool isSharedStorageSupported() const = 0; + virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) + const = 0; + virtual std::string formatSize(size_t size) const = 0; + virtual void setLowWatermarkRatio(double ratio) const = 0; + virtual void setHighWatermarkRatio(double ratio) const = 0; + virtual ssize_t getLowWatermarkValue() const = 0; + virtual size_t getLowWatermarkLimit() const = 0; + virtual size_t getHighWatermarkLimit() const = 0; + virtual size_t getTotalAllocatedMemory() const = 0; + virtual size_t getCurrentAllocatedMemory() const = 0; + virtual size_t getDriverAllocatedMemory() const = 0; + virtual size_t getRecommendedMaxMemory() const = 0; + virtual std::pair getSharedBufferPtr( + const void* ptr) const = 0; + virtual bool recordEvents(c10::ArrayRef buffers) const = 0; + virtual bool waitForEvents(c10::ArrayRef buffers) const = 0; +}; + +class IMpsAllocatorCallback { + public: + enum class EventType { + ALLOCATED, // buffer got allocated to be used immediately + RECYCLED, // buffer pulled from free list to be reused + FREED, // buffer put to free list for future recycling + RELEASED, // buffer memory released + ALLOCATION_FAILED // buffer allocation failed + }; + virtual ~IMpsAllocatorCallback() = default; + virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0; +}; + +// MPS allocator will execute every registered callback when a block of memory +// is freed. +TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback); +#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \ + C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__) + +IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false); + +bool isMPSPinnedPtr(const void* data); + +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSDevice.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSDevice.h new file mode 100644 index 0000000000000000000000000000000000000000..f15bf5e3f91b3ed240f7ecf8d5a9b4c524c0b893 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSDevice.h @@ -0,0 +1,78 @@ +// Copyright © 2022 Apple Inc. + +#pragma once +#include +#include +#include +#include + +#ifdef __OBJC__ +#include +#include +typedef id MTLDevice_t; +#else +typedef void* MTLDevice_t; +#endif + +namespace at::mps { + +// Helper enum to check if a MPSGraph op is supported in a given macOS version +enum class MacOSVersion : uint32_t { + MACOS_VER_13_1_PLUS = 0, + MACOS_VER_13_2_PLUS, + MACOS_VER_13_3_PLUS, + MACOS_VER_14_0_PLUS, + MACOS_VER_14_4_PLUS, + MACOS_VER_15_0_PLUS, + MACOS_VER_15_1_PLUS, + MACOS_VER_15_2_PLUS, +}; + +//----------------------------------------------------------------- +// MPSDevice +// +// MPSDevice is a singleton class that returns the default device +//----------------------------------------------------------------- + +class TORCH_API MPSDevice { + public: + /** + * MPSDevice should not be cloneable. + */ + MPSDevice(MPSDevice& other) = delete; + /** + * MPSDevice should not be assignable. 
+ */ + void operator=(const MPSDevice&) = delete; + /** + * Gets single instance of the Device. + */ + static MPSDevice* getInstance(); + /** + * Returns the single device. + */ + MTLDevice_t device() { + return _mtl_device; + } + /** + * Returns whether running on Ventura or newer + */ + bool isMacOS13Plus(MacOSVersion version) const; + + ~MPSDevice(); + + private: + static MPSDevice* _device; + MTLDevice_t _mtl_device; + MPSDevice(); +}; + +TORCH_API bool is_available(); +TORCH_API bool is_macos_13_or_newer(MacOSVersion version); +TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); + +inline Device getDeviceFromPtr(void* ptr) { + return {c10::DeviceType::MPS, 0}; +} + +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSEvent.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..101a37fc1bd4a2fb574104df2b14145dbf833899 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSEvent.h @@ -0,0 +1,105 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace at::mps { + +// NOTE: don't create instances of this class directly. +// Use MPSEventPool to acquire instances of MPSEvent. +class MPSEvent { + public: + explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing); + ~MPSEvent(); + + // records an event on the stream + void record(bool needsLock, bool syncEvent = false); + // makes all future work submitted to the stream wait for this event. + bool wait(bool needsLock, bool syncEvent = false); + // schedules a notifyListener callback for the event. + bool notify(bool needsLock, MTLSharedEventNotificationBlock block); + // checks if events are already signaled. + bool query() const; + // blocks the CPU thread until all the GPU work that were scheduled + // prior to recording this event are completed. 
+ bool synchronize(); + // resets this event with new parameters in case it gets reused from the event + // pool + void reset(MPSStream* stream, bool enable_timing); + // returns the unique ID of the event instance + id_t getID() const { + return m_id; + } + // returns the completion timestamp of the event + uint64_t getCompletionTime() const { + return m_completion_time; + } + // if already recorded, waits for cpu_sync_cv to be signaled + void waitForCpuSync(); + + private: + id_t m_id; + // enables measuring the completion time of the notifyListener of this event + bool m_enable_timing; + uint64_t m_signalCounter = 0; + MPSStream* m_stream = nullptr; + MTLSharedEvent_t m_event = nullptr; + MTLSharedEventListener* m_listener = nullptr; + // used to sync the events created on this Stream with CPU + std::mutex m_cpu_sync_mutex{}; + std::condition_variable m_cpu_sync_cv{}; + // CondVar predicate to sync the events created on this Stream with CPU + bool m_cpu_sync_completed = false; + // used to compute elapsed time + uint64_t m_completion_time = 0; + + void recordLocked(bool syncEvent); + bool waitLocked(bool syncEvent); + bool notifyLocked(MTLSharedEventNotificationBlock block); + void notifyCpuSync(); + static uint64_t getTime() { + return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW); + } +}; + +typedef std::unique_ptr> MPSEventPtr; + +class MPSEventPool { + public: + explicit MPSEventPool(MPSStream* default_stream); + ~MPSEventPool(); + + MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream); + void emptyCache(); + + // these are mainly used for MPSHooks and torch.mps.Event() bindings + id_t acquireEvent(bool enable_timing); + void releaseEvent(id_t event_id); + void recordEvent(id_t event_id, bool syncEvent); + void waitForEvent(id_t event_id, bool syncEvent); + void synchronizeEvent(id_t event_id); + bool queryEvent(id_t event_id); + // returns elapsed time between two recorded events in milliseconds + double elapsedTime(id_t start_event_id, id_t end_event_id); + + private: + MPSStream* m_default_stream = nullptr; + std::recursive_mutex m_mutex; + std::stack> m_pool{}; + // dictionary to associate event IDs with event objects + // used to retain in-use events out of the pool + // for torch.mps.Event() bindings. + std::unordered_map m_in_use_events{}; + uint64_t m_event_counter = 0; + std::function m_default_deleter; + + MPSEvent* getInUseEvent(id_t event_id, bool locked = true); +}; + +// shared_ptr is used to get MPSEventPool destroyed after dependent instances +std::shared_ptr getMPSEventPool(); + +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..4cab0b7f7879dd62c2c62ff4c2fd1cbd22766e10 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h @@ -0,0 +1,61 @@ +// Copyright © 2022 Apple Inc. 
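A hedged sketch of timing GPU work with the event-pool entry points declared above, the same calls that back torch.mps.Event (editor's example): the helper name is illustrative, and passing syncEvent=true is an assumption rather than the library's canonical sequence.

#include <ATen/mps/MPSEvent.h>

#include <functional>

static double timeOnMPS(const std::function<void()>& submitWork) {
  auto pool = at::mps::getMPSEventPool();
  const auto start = pool->acquireEvent(/*enable_timing=*/true);
  const auto stop  = pool->acquireEvent(/*enable_timing=*/true);

  pool->recordEvent(start, /*syncEvent=*/true);
  submitWork();                    // enqueue MPS work here
  pool->recordEvent(stop, /*syncEvent=*/true);

  pool->synchronizeEvent(stop);    // block the CPU until the stop event signals
  const double elapsed_ms = pool->elapsedTime(start, stop);
  pool->releaseEvent(start);
  pool->releaseEvent(stop);
  return elapsed_ms;
}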
+ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace mps::detail { + +constexpr uint32_t PHILOX_STATE_N = 7; +struct rng_data_pod { + std::array state{1}; + uint64_t seed = default_rng_seed_val; +}; + +TORCH_API const Generator& getDefaultMPSGenerator(); +TORCH_API Generator +createMPSGenerator(uint64_t seed_val = default_rng_seed_val); + +} // namespace mps::detail + +struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl { + // Constructors + MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val); + ~MPSGeneratorImpl() override = default; + + // MPSGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + void update_philox_counters(); + + void set_engine(at::Philox4_32 engine) { + engine_ = engine; + } + at::Philox4_32 engine() { + return engine_; + } + uint32_t* state_data() { + return data_.state.data(); + } + static DeviceType device_type() { + return DeviceType::MPS; + } + + private: + mps::detail::rng_data_pod data_; + at::Philox4_32 engine_; + + MPSGeneratorImpl* clone_impl() const override; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSGuardImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c5b2dbb193777e6053af2f4b5ca17b54a2a54f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSGuardImpl.h @@ -0,0 +1,182 @@ +// Copyright © 2022 Apple Inc. 
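A hedged sketch of the generator helpers declared above (editor's example): createMPSGenerator builds a fresh MPS generator for a given seed, and getDefaultMPSGenerator exposes the process-wide one; the returned at::Generator provides the usual current_seed accessor. The helper names are illustrative.

#include <ATen/mps/MPSGeneratorImpl.h>

static at::Generator makeSeededMPSGenerator(uint64_t seed) {
  // createMPSGenerator falls back to default_rng_seed_val when no seed is given.
  return at::mps::detail::createMPSGenerator(seed);
}

static uint64_t defaultMPSSeed() {
  const at::Generator& gen = at::mps::detail::getDefaultMPSGenerator();
  return gen.current_seed();
}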
+ +#pragma once +#include +#include +#include +#include +#include +#include + +#ifdef __OBJC__ +#include +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::mps { + +typedef MPSEvent* mpsEvent_t; + +// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl +// https://github.com/pytorch/pytorch/issues/77170 +struct TORCH_API MPSGuardImpl final + : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::MPS; + + // constructor + MPSGuardImpl() {} + explicit MPSGuardImpl(c10::DeviceType t) { + TORCH_CHECK( + t == DeviceType::MPS, + "MPSGuardImpl initialized with non-MPS DeviceType: ", + t); + } + + // returns the type + c10::DeviceType type() const override { + return c10::DeviceType::MPS; + } + + Device exchangeDevice(Device d) const override { + return Device(c10::DeviceType::MPS, 0); + } + + Device getDevice() const override { + return Device(c10::DeviceType::MPS, 0); + } + + std::optional uncheckedGetDevice() const noexcept { + return Device(c10::DeviceType::MPS, 0); + } + + void setDevice(Device d) const override { + TORCH_CHECK(d.is_mps(), "Expected a MPS device, but got ", d); + } + + void uncheckedSetDevice(Device d) const noexcept override { + // TODO: Currently setting only device 0 + } + + Stream getStream(Device d) const override { + return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); + } + + Stream getNewStream(Device, int priority = 0) const override { + (void)priority; + return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); + } + + Stream getDefaultStream(Device d) const override { + return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); + } + + // NB: These do NOT set the current device + Stream exchangeStream(Stream s) const override { + return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); + } + DeviceIndex deviceCount() const noexcept override { + if (at::hasMPS()) { + // TODO: extend it for multi-device case + return 1; + } else { + return 0; + } + } + + // Event-related functions + void createEvent(mpsEvent_t* event, const EventFlag flag) const; + + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override; + + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override; + + void block(void* event, const Stream& stream) const override; + + bool queryEvent(void* event) const override; + + void synchronizeEvent(void* event) const override; + + double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) + const override; + + void synchronizeDevice(const DeviceIndex device_index) const override; +}; + +/// A variant of OptionalDeviceGuard that is specialized for MPS. +struct OptionalMPSGuard { + explicit OptionalMPSGuard() : guard_() {} + + explicit OptionalMPSGuard(std::optional device_opt) + : guard_(device_opt) {} + + /// Set the current MPS device to the passed device index, if it is not + /// nullopt + explicit OptionalMPSGuard(std::optional device_index_opt) + : guard_(device_index_opt) {} + + // Copy is not allowed + OptionalMPSGuard(const OptionalMPSGuard&) = delete; + OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete; + OptionalMPSGuard(OptionalMPSGuard&& other) = delete; + OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete; + + /// Sets the MPS device to the given device, initializing the guard if it + /// is not already initialized. 
Errors if the given device is not a MPS + /// device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the MPS device to the given device, initializing the guard if it is + /// not already initialized. Errors if the given device is not a MPS device. + void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the MPS device to the given device index, initializing the guard if + /// it is not already initialized. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set immediately prior to initialization of the + /// guard, or nullopt if the guard is uninitialized. + std::optional original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device, if the guard is initialized, + /// or nullopt if the guard is uninitialized. + std::optional current_device() const { + return guard_.current_device(); + } + + /// Restore the original MPS device, resetting this guard to uninitialized + /// state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + +C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl) + +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSHooks.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..e2088f43d99ce66aaa72f003f66c168006747ef2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSHooks.h @@ -0,0 +1,71 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace at::mps { + +// The real implementation of MPSHooksInterface +struct MPSHooks : public at::MPSHooksInterface { + MPSHooks(at::MPSHooksArgs) {} + void init() const override; + + // MPSDevice interface + bool hasMPS() const override; + bool isOnMacOSorNewer(unsigned major, unsigned minor) const override; + + Device getDeviceFromPtr(void* data) const override; + + // MPSGeneratorImpl interface + const Generator& getDefaultGenerator( + DeviceIndex device_index = -1) const override; + Generator getNewGenerator(DeviceIndex device_index = -1) const override; + + // MPSStream interface + void deviceSynchronize() const override; + void commitStream() const override; + void* getCommandBuffer() const override; + void* getDispatchQueue() const override; + + // MPSAllocator interface + Allocator* getMPSDeviceAllocator() const override; + void emptyCache() const override; + size_t getCurrentAllocatedMemory() const override; + size_t getDriverAllocatedMemory() const override; + size_t getRecommendedMaxMemory() const override; + void setMemoryFraction(double ratio) const override; + bool isPinnedPtr(const void* data) const override; + Allocator* getPinnedMemoryAllocator() const override; + + // MPSProfiler interface + void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) + const override; + void profilerStopTrace() const override; + + // MPSEvent interface + uint32_t acquireEvent(bool enable_timing) const override; + void releaseEvent(uint32_t event_id) const override; + void recordEvent(uint32_t event_id) const override; + void waitForEvent(uint32_t event_id) const override; + void synchronizeEvent(uint32_t event_id) const override; + bool queryEvent(uint32_t event_id) const override; + double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t 
end_event_id) + const override; + + bool isBuilt() const override { + return true; + } + bool isAvailable() const override { + return hasMPS(); + } + bool hasPrimaryContext(DeviceIndex device_index) const override { + // When MPS is available, it is always in use for the one device. + return true; + } +}; + +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSProfiler.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSProfiler.h new file mode 100644 index 0000000000000000000000000000000000000000..59a3d50076f028ff94611600d2fda42a02f45866 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSProfiler.h @@ -0,0 +1,467 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifndef __OBJC__ +typedef void* MTLCaptureManager; +#endif + +namespace at::mps { + +namespace Profiler { + +struct BaseInfo { + // profiling info types + enum class Type { + GRAPH, + KERNEL, + COPY, + CPU_FALLBACK, + }; + + BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) + : type(infoType), profileId(Id), handle(Handle) {} + virtual ~BaseInfo() = default; + + // type of profiling info + Type type; + // unique profile ID for execution instances of operations or copies + uint64_t profileId; + // ID generated by os_signpost + // since it's possible to use event and interval-based signposts at the + // same time, we need separate IDs for each. + os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0; + // accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - + // GPUStartTime") + std::atomic totalGpuTime{0.0}; + // accumulated Scheduling time in ms (obtained from CompletionHandler's + // "KernelEndTime - KernelStartTime") + std::atomic totalSchedulingTime{0.0}; + // indicates if the operation or copy execution has completed + std::atomic_bool completed{false}; + // handle used to identify the profile info's instance (usually the pointer) + const uintptr_t handle; + + virtual const std::string toString( + double gpuTime = 0, + double schedulingTime = 0) const; + // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()]) + static std::string buildTensorString( + const Tensor& tensor, + bool includeBufferId = false); + static uint64_t getTime() { + return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW); + } +}; + +struct OperationInfo : BaseInfo { + OperationInfo( + const void* Handle, + bool IsGraph, + uint64_t Id, + const std::string& StrKey) + : BaseInfo(IsGraph ? 
Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), + strKey(StrKey) {} + + uint64_t runCount = 0; + std::string strKey; + + const std::string toString(double gpuTime = 0, double schedulingTime = 0) + const override; + + // builds a string for a kernel + static std::string buildKernelString( + const std::string& kernelName, + const TensorList& tensors, + bool includeBufferId = false) { + std::stringstream kernelStr; + kernelStr << kernelName; + for (const Tensor& tensor : tensors) { + kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId); + } + return kernelStr.str(); + } +}; + +struct CpuFbInfo : BaseInfo { + CpuFbInfo(uint64_t Id, const std::string& OpName) + : BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {} + + uint64_t runCount = 0; + // the current and total overhead of copies in bytes required to convert the + // Op's input tensors from MPS to CPU and then output from CPU back to MPS + size_t currentCopyOverhead = 0; + size_t totalCopyOverhead = 0; + std::string opName; + std::string strKey; + uint64_t startTime = 0; + + const std::string toString(double gpuTime = 0, double schedulingTime = 0) + const override; + + void updateCopyOverhead(const TensorList& tensors) { + currentCopyOverhead = 0; + for (const Tensor& tensor : tensors) { + if (tensor.defined()) { + currentCopyOverhead += tensor.nbytes(); + } + } + totalCopyOverhead += currentCopyOverhead; + } +}; + +struct CopyInfo : BaseInfo { + enum class Kind { + MPS_TO_MPS, + MPS_TO_CPU, + CPU_TO_MPS, + }; + + CopyInfo( + const void* Handle, + size_t Length, + uint64_t Id, + bool IsNonBlocking, + bool UsesBlitter) + : BaseInfo(Type::COPY, Id, uintptr_t(Handle)), + kind(Kind::MPS_TO_MPS), + length(Length), + isNonBlocking(IsNonBlocking), + usesBlitter(UsesBlitter) {} + + Kind kind; + size_t length; + bool isNonBlocking; + bool usesBlitter; + std::string srcStrKey; + std::string dstStrKey; + // for copies that don't use blitters, we measure CPU time + uint64_t startTime = 0; + + const std::string toString(double gpuTime = 0, double schedulingTime = 0) + const override; + + static std::string buildTensorString( + const void* buffer, + const OptionalTensorRef tensor, + bool includeBufferId = false); + + static bool isStorageOnMPS( + const void* buffer, + const OptionalTensorRef tensor) { + if (tensor.has_value()) { + return tensor->device().type() == at::kMPS; + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer); + // getUnalignedBufferSize() returns -1 if input buffer is not on MPS device + return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0; + } + + static Kind getCopyKind( + const void* srcBuffer, + const void* dstBuffer, + const OptionalTensorRef srcTensor, + const OptionalTensorRef dstTensor) { + const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor); + const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS); + if (isSrcOnMPS && !isDstOnMPS) { + return Kind::MPS_TO_CPU; + } else if (!isSrcOnMPS && isDstOnMPS) { + return Kind::CPU_TO_MPS; + } + return Kind::MPS_TO_MPS; + } +}; + +struct CopyStat : CopyInfo { + explicit CopyStat(std::string CopyKindStr) + : CopyInfo(nullptr, 0, 0, false, false), + kindStr(std::move(CopyKindStr)) {} + // total number of copies + size_t totalCount = 0; + // number of Scalar copies (i.e., less than sizeof(int64)) + size_t scalarsCount = 0; + // number of blocking copies (i.e., require syncing to GPU) + size_t blockingCount = 0; + // number of copies that used memcpy(), instead of Metal Blit Encoder + 
size_t memcpyCount = 0; + // accumulated GPU time in ms for the scalar copies + std::atomic scalarsGpuTime{0.0}; + // copy kind in string type + std::string kindStr; +}; + +class MPSProfiler { + public: + // lower 16 bits used for profiler options + enum ProfileOptions : uint32_t { + OPTIONS_NONE = 0, + // ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, + // etc.) (used for convenience to not compute bit flags by OR-ing manually) + // trace all signpost types using events + ALL_SIGNPOST_EVENTS = (1 << 0), + // trace all signpost types using intervals + ALL_SIGNPOST_INTERVALS = (1 << 1), + // always wait for command buffer to finish executing after each commit + WAIT_UNTIL_COMPLETED = (1 << 2), + // for interval-based signposts, include the scheduling portion of + // Graph/Kernel/Copy executions as well. + // if flag is disable, only "GPU run time" is included in interval, + // and not schedule time. + INCLUDE_SCHEDULE_INTERVAL = (1 << 3), + + // use these if you need to trace signposts types individually (rarely + // required) trace signpost using intervals + USE_INTERVALS = (1 << 4), + // trace signpost by emitting events + USE_EVENTS = (1 << 5), + // used for sanity check (Change this when new option added) + OPTIONS_COUNT = (USE_EVENTS << 1) - 1, + }; + + // when adding new types, #define the type string in MPSProfiler.mm as well. + // upper 16 bits used for event types + enum SignpostTypes : uint32_t { + SIGNPOST_NONE = 0, + // trace signposts for PyTorch operation executions + RUN_OPERATION = (1 << 16), + // trace signposts for blitter copies + BLIT_COPY = (1 << 17), + // trace signposts for ops that fall back on CPU + CPU_FALLBACK = (1 << 18), + // used for sanity check (Change this when new type added) + SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1, + }; + + enum LogOptions : uint32_t { + LOG_NONE = 0, + + // Info logging options during execution + // ------------------------------------- + // prints operation info (id/key/run_count) during execution + OPERATION_INFO = (1 << 0), + // prints copy info (src/dst tensors/buffers, size, etc.) during execution + COPY_INFO = (1 << 1), + // prints CPU Fallback info (id/runCount/opName/copyOverhead) during + // execution + CPU_FALLBACK_INFO = (1 << 2), + + // Profiling Statistics logging options when process terminates + // ------------------------------------------------------------ + // prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before + // process terminates this is convenient to not combine following stats bit + // flags manually + ALL_STATS = (1 << 3), + // prints operation stats (GPU times, run count, etc.) before process + // terminates + OPERATION_STATS = (1 << 4), + // prints copies stats (GPU times, copy kinds, sizes, etc.) before process + // terminates + COPY_STATS = (1 << 5), + // prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies + // for tensors, etc.) 
before process terminates + CPU_FALLBACK_STATS = (1 << 6), + + // Metadata format options when logging the info + // --------------------------------------------- + // if enabled, includes GPU run time in metadata (i.e., + // GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324 + // ms]) + INCLUDE_GPU_TIME = (1 << 7), + // if enabled, includes GPU scheduling time in metadata separately + // (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers) + // e.g., [GPU=0.324 ms, KRNL=0.036 ms] + INCLUDE_KERNEL_TIME = (1 << 8), + // if enabled, includes the unique buffer ID in metadata for the storage + // of a tensor that was allocated on MPSAllocator. This is useful (along + // with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are + // involved with various operations. + INCLUDE_BUFFER_ID = (1 << 9), + + // used for sanity check (Change this when new option added) + LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1, + }; + + explicit MPSProfiler(); + ~MPSProfiler(); + + // the handle is either "MPSGraph*" or "id" for Metal + // Kernels the beginProfile*() functions return a profileId which is unique + // per graph/kernel/copy + uint64_t beginProfileKernel( + const void* handle, + const std::string& strKey, + bool isGraph); + uint64_t beginProfileKernel( + const void* handle, + const std::string& kernelName, + const TensorList& tensors); + uint64_t beginProfileCopy( + const void* srcBuffer, + const void* dstBuffer, + const OptionalTensorRef srcTensor, + const OptionalTensorRef dstTensor, + size_t length, + bool isNonBlocking, + bool usesBlitter = true); + uint64_t beginProfileCPUFallback( + const std::string& opName, + const TensorList& tensors); + void beginProfileGPUInterval(const void* handle); + + void endProfileCopy(uint64_t profileId, SyncType syncType); + void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE); + void endProfileCPUFallback(const std::string& opName); + + // these are used to hook into Python bindings for torch.mps.profiler module. + // this enables generating OS Signpost traces from MPSProfiler on-demand + // during runtime (instead of environment variables). + // The "mode" could be either "interval", "event", or both "interval,event" + // for interval-based and/or event-based signpost tracing. + void StartTrace(const std::string& mode, bool waitUntilCompleted); + void StopTrace(); + + // Abstractions for GPU trace capturing + bool isCaptureEnabled() const; + bool isCapturing() const; + void startCapture(const std::string& name, MPSStream* stream = nullptr); + void stopCapture(MPSStream* stream = nullptr); + + // convenience functions to indicate whether signpost tracing or + // logging are enabled for the SignpostTypes + bool isOperationProfilingEnabled() const { + return (m_signpost_types & SignpostTypes::RUN_OPERATION) || + (m_log_options & + (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS)); + } + bool isCopyProfilingEnabled() const { + return (m_signpost_types & SignpostTypes::BLIT_COPY) || + (m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS)); + } + bool isCPUFallbackProfilingEnabled() const { + return (m_signpost_types & SignpostTypes::CPU_FALLBACK) || + (m_log_options & + (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS)); + } + bool isSignpostTracingEnabled() const { + return (m_signpost_types != SignpostTypes::SIGNPOST_NONE); + } + + private: + // indicates what type of signpost types are enabled and traced by MPS + // profiler. 
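Per the comments above, the lower 16 bits of the profiler state carry ProfileOptions while the upper 16 bits carry SignpostTypes, and the *_COUNT sentinels are "every defined bit" masks built as (highest flag << 1) - 1. A minimal self-contained check of that arithmetic; the constants are copied from the enum definitions above purely so the snippet compiles on its own.

#include <cassert>
#include <cstdint>

// Values mirror the ProfileOptions / SignpostTypes enums above.
constexpr uint32_t ALL_SIGNPOST_INTERVALS = 1u << 1;
constexpr uint32_t WAIT_UNTIL_COMPLETED   = 1u << 2;
constexpr uint32_t USE_EVENTS             = 1u << 5;
constexpr uint32_t OPTIONS_COUNT          = (USE_EVENTS << 1) - 1;

constexpr uint32_t RUN_OPERATION  = 1u << 16;
constexpr uint32_t CPU_FALLBACK   = 1u << 18;
constexpr uint32_t SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1;

int main() {
  // Options combine by OR-ing and are tested with AND, much as the
  // isOperationProfilingEnabled()-style helpers do.
  uint32_t opts = ALL_SIGNPOST_INTERVALS | WAIT_UNTIL_COMPLETED;
  assert(opts & WAIT_UNTIL_COMPLETED);
  assert(!(opts & USE_EVENTS));

  // The *_COUNT sentinels cover every defined bit, which is why the comments
  // say to update them whenever a new flag or type is added.
  assert(OPTIONS_COUNT == 0x3F);
  assert((RUN_OPERATION & SIGNPOST_COUNT) == RUN_OPERATION);
  return 0;
}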
+ uint32_t m_signpost_types = 0; + uint32_t m_profile_options = 0; + uint32_t m_log_options = 0; + uint64_t m_kernel_counter = 0; + uint64_t m_graph_counter = 0; + uint64_t m_cpu_fb_counter = 0; + uint64_t m_copy_counter = 0; + // technically, it's possible to trace both events and intervals at the same + // time so we use separate os_log categories for them + os_log_t m_os_log_events; + os_log_t m_os_log_intervals; + // stats logging could run either from destructor or signal handler + // so this is used to check if logging has already started. + std::atomic_bool hasLoggedStats{false}; + // indicates there are pending completionHandler callbacks that haven't been + // called yet. + std::atomic_bool hasPendingCompletionHandlers{false}; + // used to capture sigint signal to log profiling stats + static struct sigaction currentSigint, previousSigint; + + // We use the following lists for two reasons: + // 1- for interval-based signposts the "begin" point won't be in same function + // as the "end" point where we need to be able to retrieve signpost's info + // 2- if Operations info need to be logged when process ends using + // LogOptions::OPERATION_INFO. + + // the pointer key for this map is either "MPSGraph*" or + // "id" for Metal Kernels this list is retained and + // could be logged along with aggregate profiling numbers when the process + // ends. + std::unordered_map> + m_op_info_list{}; + // the string key for this map is the op name that we fall back to execute on + // CPU this list is retained and could be logged along with aggregate + // profiling numbers when the process ends. + std::unordered_map> + m_cpu_fb_info_list{}; + // this list contains the info for copies, and its key is the unique profileId + // which is generated from m_copy_counter + // The copyInfo list is not retained. + std::unordered_map> m_copy_info_list{}; + // a short list that contains copy stats + std::unordered_map> + m_copy_stat_list{}; + + mutable MTLCaptureManager* captureManager = nil; + unsigned captureCount = 0; + + void initialize(); + void beginProfileExecution(BaseInfo& info, bool cpuExecution = false); + void endProfileExecution( + BaseInfo& info, + os_signpost_id_t event_signpost_id, + os_signpost_id_t interval_signpost_id, + double gpuTime, + double schedulingTime); + void addProfilerScheduledHandler(BaseInfo& info); + void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType); + void emitSignpostEvent( + SignpostTypes signpost_type, + os_signpost_id_t signpost_id, + const std::string& msg) const; + void beginSignpostInterval( + SignpostTypes signpost_type, + os_signpost_id_t signpost_id, + const std::string& msg) const; + void endSignpostInterval( + SignpostTypes signpost_type, + os_signpost_id_t signpost_id) const; + + void updateCopyStats( + const CopyInfo& copyInfo, + double gpuTime, + double schedulingTime); + // returns true if logging the profiling info "during the execution" is + // enabled + bool isProfileInfoLoggingEnabled( + BaseInfo::Type infoType, + bool isExecutionEnded); + // logs all the profiling stats that are enabled + void logProfilingStats(); + // logs kernel profiling stats when the process ends. + void logOperationsProfilingStats(std::FILE* f) const; + // logs CPU Fallback profiling stats when the process ends. + void logCPUFallbackProfilingStats(std::FILE* f) const; + // logs copy profiling stats when the process ends. 
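The pair of sigaction members above exists so the profiler can log its statistics when the process receives Ctrl-C and still hand the signal on to whatever handler was installed before it. A generic sketch of that install-and-chain pattern, using only async-signal-safe calls; this is a hypothetical standalone program, not the profiler's actual implementation.

#include <csignal>
#include <unistd.h>

static struct sigaction previousSigint;

static void handleSigint(int signum) {
  // Emit a marker, restore the previous disposition, then re-raise so the
  // previously installed behaviour (typically termination) still runs.
  const char msg[] = "logging profiling stats before exit\n";
  write(STDERR_FILENO, msg, sizeof(msg) - 1);
  sigaction(SIGINT, &previousSigint, nullptr);
  raise(signum);
}

int main() {
  struct sigaction current {};
  current.sa_handler = handleSigint;
  sigemptyset(&current.sa_mask);
  // Keep the old disposition around so the handler can chain to it later.
  sigaction(SIGINT, &current, &previousSigint);
  pause();  // wait for Ctrl-C in this toy example
  return 0;
}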
+ void logCopyProfilingStats(std::FILE* f) const; + + os_signpost_id_t generateSignpostId( + os_signpost_type_t signpostType, + const void* ptr = nullptr); + static SignpostTypes getSignpostType(BaseInfo::Type infoType); + static void handleIntSignal(int signal); +}; + +} // namespace Profiler + +Profiler::MPSProfiler& getMPSProfiler(); + +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSStream.h b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSStream.h new file mode 100644 index 0000000000000000000000000000000000000000..ed26ed0c211c6513429126c893dd3d21a33a8f66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/mps/MPSStream.h @@ -0,0 +1,158 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#ifdef __OBJC__ +#include +#include +#include +#include +typedef MPSCommandBuffer* MPSCommandBuffer_t; +typedef id MTLCommandQueue_t; +typedef id MTLComputeCommandEncoder_t; +typedef id MTLSharedEvent_t; +typedef id MTLDevice_t; +typedef id MTLBuffer_t; +#else +#include +typedef void* MPSCommandBuffer_t; +typedef void* MPSGraph; +typedef void* MPSGraphExecutionDescriptor; +typedef void* MPSGraphCompilationDescriptor; +typedef void* MTLCommandQueue_t; +typedef void* MTLComputeCommandEncoder_t; +typedef void* MTLSharedEvent_t; +typedef void* MTLDevice_t; +typedef void* MTLBuffer_t; +typedef void* MTLCommandBufferHandler; +typedef void* NSDictionary; +#define nil NULL +#endif + +namespace at::mps { + +//----------------------------------------------------------------- +// MPSStream +//----------------------------------------------------------------- + +enum class SyncType { + NONE, // no commit to command buffer + COMMIT, // commit and flush the command buffer + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish + COMMIT_AND_CONTINUE, // commit and continue with a new underlying command buffer + COMMIT_ADAPTIVE, // commit adaptively based on available memory +}; + +class TORCH_API MPSStream { + public: + enum Unchecked { UNCHECKED }; + + /// Construct a MPSStream from a Stream. This construction is checked, + /// and will raise an error if the Stream is not, in fact, a MPS stream. + explicit MPSStream(Stream stream); + + ~MPSStream(); + + MTLCommandQueue_t commandQueue() const { + return _commandQueue; + } + + dispatch_queue_t queue() const { + return _serialQueue; + } + + MPSCommandBuffer_t commandBuffer(); + MTLComputeCommandEncoder_t commandEncoder(); + void endKernelCoalescing(); + void synchronize(SyncType syncType); + void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE); + void copy(MTLBuffer_t srcBuffer, + MTLBuffer_t dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + uint64_t profileId, + SyncType syncType = SyncType::NONE); + void copy_and_sync(MTLBuffer_t srcBuffer, + MTLBuffer_t dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + bool non_blocking, + uint64_t profileId); + void executeMPSGraph(MPSGraph* mpsGraph, + NSDictionary* feeds, + NSDictionary* results, + SyncType syncType = SyncType::NONE); + void addCompletedHandler(MTLCommandBufferHandler block); + + /// Get the MPS device index that this stream is associated with. + c10::DeviceIndex device_index() const { + return _stream.device_index(); + } + + MTLCommandQueue_t stream() const { + return _commandQueue; + } + + MTLDevice_t device() const; + + /// Explicit conversion to Stream. 
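MPSStream.h above uses the usual dual-compilation trick: when included from an Objective-C++ translation unit (__OBJC__ defined) the handle typedefs are real Metal/MPS object types, otherwise they degrade to void* so plain C++ code can still compile against the same declarations. A minimal sketch of the same pattern with a made-up GpuQueue_t handle and submitWork function, shown only to illustrate the idiom.

// gpu_queue.h -- illustrative header, not part of PyTorch.
#pragma once

#ifdef __OBJC__
#import <Metal/Metal.h>
// Objective-C++ callers see the real protocol-qualified handle type.
typedef id<MTLCommandQueue> GpuQueue_t;
#else
// Pure C++ callers only ever pass the handle around opaquely.
typedef void* GpuQueue_t;
#endif

// The API below compiles identically in both kinds of translation unit.
void submitWork(GpuQueue_t queue);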
+ Stream unwrap() const { + return _stream; + } + + private: + Stream _stream; + MTLCommandQueue_t _commandQueue = nil; + MPSCommandBuffer_t _commandBuffer = nil; + MPSCommandBuffer_t _prevCommandBuffer = nil; + MTLComputeCommandEncoder_t _commandEncoder = nil; + MPSGraphExecutionDescriptor* _executionDescriptor = nil; + MPSGraphCompilationDescriptor* _compilationDescriptor = nil; + dispatch_queue_t _serialQueue = nullptr; + // CommitAndContinue is enabled by default + bool _enableCommitAndContinue = true; + + // use synchronize() to access any of these commit functions outside MPSStream + void commit(); + void commitAndWait(); + void commitAndContinue(); + void flush(); +}; + +/** + * Get the current MPS stream + */ +TORCH_API MPSStream* getCurrentMPSStream(); + +/** + * Get the default MPS stream + */ +TORCH_API MPSStream* getDefaultMPSStream(); + +//----------------------------------------------------------------- +// MPSStreamImpl +//----------------------------------------------------------------- + +class TORCH_API MPSStreamImpl { + public: + /** + * Gets single instance of the MPSStream. + */ + static MPSStream* getInstance(); + + private: + static MPSStream* _stream; + MPSStreamImpl(); +}; + +} // namespace at::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Activation.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Activation.h new file mode 100644 index 0000000000000000000000000000000000000000..02cc5090f1dd0d033678c525445cff69c06b2ae5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Activation.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +struct TensorIteratorBase; +class TensorBase; +} + +namespace at::native { + +using structured_activation_fn = void (*)(TensorIteratorBase&); +using structured_activation_backward_fn = void (*)(TensorIteratorBase&); + +using activation_fn = void (*)(TensorIterator&); +using activation_backward_fn = void (*)(TensorIterator&); +using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&); +using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&); +using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&); +using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&); +using hardsigmoid_fn = void(*)(TensorIteratorBase&); +using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&); +using hardswish_fn = void(*)(TensorIterator&); +using hardswish_backward_fn = void(*)(TensorIterator&); +using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&); +using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool); +using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&); +using gelu_fn = void (*)(TensorIteratorBase&, GeluType); +using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType); +using glu_jvp_fn = void (*)(TensorIteratorBase&); + +DECLARE_DISPATCH(elu_fn, 
elu_stub) +DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub) +DECLARE_DISPATCH(softplus_fn, softplus_stub) +DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub) +DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub) +DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub) +DECLARE_DISPATCH(threshold_fn, threshold_stub) +DECLARE_DISPATCH(gelu_fn, GeluKernel) +DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel) +DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub) +DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub) +DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub) +DECLARE_DISPATCH(hardswish_fn, hardswish_stub) +DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub) +DECLARE_DISPATCH(shrink_fn, hardshrink_stub) +DECLARE_DISPATCH(softshrink_fn, softshrink_stub) +DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub) +DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub) +DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub) +DECLARE_DISPATCH(structured_activation_fn, glu_stub) +DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub) +DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub) +DECLARE_DISPATCH(structured_activation_fn, silu_stub) +DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub) +DECLARE_DISPATCH(structured_activation_fn, mish_stub) +DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub) +DECLARE_DISPATCH(activation_fn, prelu_stub) +DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/AdaptivePooling.h b/phivenv/Lib/site-packages/torch/include/ATen/native/AdaptivePooling.h new file mode 100644 index 0000000000000000000000000000000000000000..0fce2aa38779471b37137f9d7906dd3a0ceea361 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/AdaptivePooling.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::native { + +using adaptive_avg_pooling2d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling2d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling2d_fn, adaptive_avg_pool2d_kernel) +DECLARE_DISPATCH(adaptive_avg_pooling2d_backward_fn, adaptive_avg_pool2d_backward_kernel) + +using adaptive_max_pooling2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size); +using adaptive_max_pooling2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); +DECLARE_DISPATCH(adaptive_max_pooling2d_fn, adaptive_max_pool2d_kernel) +DECLARE_DISPATCH(adaptive_max_pooling2d_backward_fn, adaptive_max_pool2d_backward_kernel) + +using adaptive_avg_pooling3d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling3d_fn, adaptive_avg_pool3d_kernel) +DECLARE_DISPATCH(adaptive_avg_pooling3d_backward_fn, adaptive_avg_pool3d_backward_kernel) + +using adaptive_max_pooling3d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size); +using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); +DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel) 
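Each DECLARE_DISPATCH above only declares a per-device dispatch table entry; kernels register into it elsewhere and operators call through it with a device type as the first argument. The sketch below is a simplified stand-in for what ATen's DispatchStub machinery provides, written so it runs on its own; DeviceType, DispatchSlot and the kernel names are illustrative, not the real ATen types.

#include <cassert>

enum class DeviceType { CPU, MPS, COUNT };

struct TensorIteratorStub {};  // stand-in for at::TensorIteratorBase

using activation_fn = void (*)(TensorIteratorStub&);

// A per-operator slot: one function pointer per device, filled in at
// registration time and invoked through operator().
struct DispatchSlot {
  activation_fn table[static_cast<int>(DeviceType::COUNT)] = {};
  void registerKernel(DeviceType d, activation_fn fn) {
    table[static_cast<int>(d)] = fn;
  }
  void operator()(DeviceType d, TensorIteratorStub& iter) const {
    assert(table[static_cast<int>(d)] && "no kernel registered for device");
    table[static_cast<int>(d)](iter);
  }
};

static DispatchSlot hardswish_slot;  // cf. DECLARE_DISPATCH(hardswish_fn, hardswish_stub)

static void cpu_hardswish_kernel(TensorIteratorStub&) {
  // elementwise x * relu6(x + 3) / 6 would run over the iterator here
}

int main() {
  hardswish_slot.registerKernel(DeviceType::CPU, &cpu_hardswish_kernel);
  TensorIteratorStub iter;
  hardswish_slot(DeviceType::CPU, iter);  // the operator calls through the slot
  return 0;
}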
+DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel) + +inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (a / b) * c + ((a % b) * c) / b; +} + +inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return 1 + ((a + 1) * c - 1) / b; +} + +inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) { + int64_t ndim = gradOutput_.ndimension(); + for (const auto i : c10::irange(1, ndim)) { + TORCH_CHECK(gradOutput_.size(i) > 0, + arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, " + "but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i, + " being empty"); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/AmpKernels.h b/phivenv/Lib/site-packages/torch/include/ATen/native/AmpKernels.h new file mode 100644 index 0000000000000000000000000000000000000000..665a65259218c322a5b19f008a5c273bac61472c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/AmpKernels.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)( + TensorList, + Tensor&, + const Tensor&); + +using _amp_update_scale_cpu__fn = Tensor& (*)( + Tensor&, + Tensor&, + const Tensor&, + double, + double, + int64_t); + +DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub) +DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub) + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h b/phivenv/Lib/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h new file mode 100644 index 0000000000000000000000000000000000000000..6714ea1646a80581c9c805c685d10694f6f8641f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h @@ -0,0 +1,321 @@ +#pragma once + +#include +#include +#include +#include + +// Forward declare TI +namespace at { +class Tensor; +struct TensorIterator; + +namespace native { +enum class TransposeType; +} + +} + +namespace at::native { + +enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss}; + +#if AT_BUILD_WITH_LAPACK() +// Define per-batch functions to be used in the implementation of batched +// linear algebra operations + +template +void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info); + +template +void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info); + +template +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info); + +template +void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); + +template +void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); + +template +void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info); + +template +void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info); + +template +void lapackGels(char trans, int m, int n, int nrhs, + scalar_t *a, int lda, 
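The start_index/end_index helpers above compute, in integer arithmetic, floor(a*c/b) and ceil((a+1)*c/b): the half-open input window that output element a of an adaptive pool covers when c input positions are pooled down to b outputs. For example, pooling 10 positions down to 4 outputs gives the (possibly overlapping) windows [0,3), [2,5), [5,8), [7,10). A quick self-contained check of that claim, with the two helpers copied verbatim so the snippet compiles on its own:

#include <cassert>
#include <cstdint>

inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
  return (a / b) * c + ((a % b) * c) / b;
}
inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
  return 1 + ((a + 1) * c - 1) / b;
}

int main() {
  const int64_t out = 4, in = 10;  // adaptively pool 10 elements down to 4
  const int64_t expected_start[] = {0, 2, 5, 7};
  const int64_t expected_end[]   = {3, 5, 8, 10};
  for (int64_t a = 0; a < out; ++a) {
    assert(start_index(a, out, in) == expected_start[a]);
    assert(end_index(a, out, in) == expected_end[a]);
  }
  return 0;
}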
scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info); + +template +void lapackGelsd(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + value_t *s, value_t rcond, int *rank, + scalar_t* work, int lwork, + value_t *rwork, int* iwork, int *info); + +template +void lapackGelsy(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + int *jpvt, value_t rcond, int *rank, + scalar_t *work, int lwork, value_t* rwork, int *info); + +template +void lapackGelss(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + value_t *s, value_t rcond, int *rank, + scalar_t *work, int lwork, + value_t *rwork, int *info); + +template +struct lapackLstsq_impl; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGels( + trans, m, n, nrhs, + a, lda, b, ldb, + work, lwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelsy( + m, n, nrhs, + a, lda, b, ldb, + jpvt, rcond, rank, + work, lwork, rwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelsd( + m, n, nrhs, + a, lda, b, ldb, + s, rcond, rank, + work, lwork, + rwork, iwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelss( + m, n, nrhs, + a, lda, b, ldb, + s, rcond, rank, + work, lwork, + rwork, info); + } +}; + +template +void lapackLstsq( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackLstsq_impl::call( + trans, m, n, nrhs, + a, lda, b, ldb, + work, lwork, info, + jpvt, rcond, rank, rwork, + s, + iwork); +} + +template +void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info); + +template +void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info); + +template +void lapackLdlHermitian( + char uplo, + int n, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* work, + int lwork, + int* info); + +template +void lapackLdlSymmetric( + char uplo, + int n, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* work, + int lwork, + int* info); + +template +void lapackLdlSolveHermitian( + char uplo, + int n, + int nrhs, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* b, + int ldb, + int* info); + +template 
+void lapackLdlSolveSymmetric( + char uplo, + int n, + int nrhs, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* b, + int ldb, + int* info); + +template +void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); +#endif + +#if AT_BUILD_WITH_BLAS() +template +void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb); +#endif + +using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/); +DECLARE_DISPATCH(cholesky_fn, cholesky_stub) + +using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/); + +DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub) + +using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/); + +DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub) + +using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/); +DECLARE_DISPATCH(geqrf_fn, geqrf_stub) + +using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/); +DECLARE_DISPATCH(orgqr_fn, orgqr_stub) + +using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/); +DECLARE_DISPATCH(ormqr_fn, ormqr_stub) + +using linalg_eigh_fn = void (*)( + const Tensor& /*eigenvalues*/, + const Tensor& /*eigenvectors*/, + const Tensor& /*infos*/, + bool /*upper*/, + bool /*compute_eigenvectors*/); +DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub) + +using lstsq_fn = void (*)( + const Tensor& /*a*/, + Tensor& /*b*/, + Tensor& /*rank*/, + Tensor& /*singular_values*/, + Tensor& /*infos*/, + double /*rcond*/, + std::string /*driver_name*/); +DECLARE_DISPATCH(lstsq_fn, lstsq_stub) + +using triangular_solve_fn = void (*)( + const Tensor& /*A*/, + const Tensor& /*B*/, + bool /*left*/, + bool /*upper*/, + TransposeType /*transpose*/, + bool /*unitriangular*/); +DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub) + +using lu_factor_fn = void (*)( + const Tensor& /*input*/, + const Tensor& /*pivots*/, + const Tensor& /*infos*/, + bool /*compute_pivots*/); +DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub) + +using unpack_pivots_fn = void(*)( + TensorIterator& iter, + const int64_t dim_size, + const int64_t max_pivot); +DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub) + +using lu_solve_fn = void (*)( + const Tensor& /*LU*/, + const Tensor& /*pivots*/, + const Tensor& /*B*/, + TransposeType /*trans*/); +DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub) + +using ldl_factor_fn = void (*)( + const Tensor& /*LD*/, + const Tensor& /*pivots*/, + const Tensor& /*info*/, + bool /*upper*/, + bool /*hermitian*/); +DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub) + +using svd_fn = void (*)( + const Tensor& /*A*/, + const bool /*full_matrices*/, + const bool /*compute_uv*/, + const std::optional& /*driver*/, + const Tensor& /*U*/, + const Tensor& /*S*/, + const Tensor& /*Vh*/, + const Tensor& /*info*/); +DECLARE_DISPATCH(svd_fn, svd_stub) + +using ldl_solve_fn = void (*)( + const Tensor& /*LD*/, + const Tensor& /*pivots*/, + const Tensor& /*result*/, + bool /*upper*/, + bool /*hermitian*/); +DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub) +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/BinaryOps.h 
b/phivenv/Lib/site-packages/torch/include/ATen/native/BinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..870c1a54d48a4d36c9fa6456c810c7eb1702b6cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/BinaryOps.h @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include +#include + + +namespace at { +struct TensorIterator; +struct TensorIteratorBase; +} + +namespace at::native { + +inline void alpha_check(const ScalarType dtype, const Scalar& alpha) { + TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool, + "Boolean alpha only supported for Boolean results."); + TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) + || alpha.isIntegral(true), + "For integral input tensors, argument alpha must not be a floating point number."); + TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(), + "For non-complex input tensors, argument alpha must not be a complex number.") +} + +// Basic checking for all sub functions. +inline void sub_check(const TensorBase& self, const TensorBase& other) { + TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool, + "Subtraction, the `-` operator, with two bool tensors is not supported. " + "Use the `^` or `logical_xor()` operator instead.") + TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool, + "Subtraction, the `-` operator, with a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead."); +} + +inline void sub_check(const TensorBase& self, const Scalar& scalar) { + TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(), + "Subtraction, the `-` operator, with two bool tensors is not supported. " + "Use the `^` or `logical_xor()` operator instead.") + TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(), + "Subtraction, the `-` operator, with a bool tensor is not supported. 
" + "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead."); +} + +using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +using structured_binary_fn_double = void(*)(TensorIteratorBase&, double); +using structured_binary_fn = void(*)(TensorIteratorBase&); + +using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +using binary_fn_double = void(*)(TensorIterator&, double); +using binary_fn = void(*)(TensorIterator&); +using binary_clamp_fn_alpha = + void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val); + +// NB: codegenned +DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub) + +DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub) +DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub) +DECLARE_DISPATCH(structured_binary_fn, mul_stub) +DECLARE_DISPATCH(structured_binary_fn, div_true_stub) +DECLARE_DISPATCH(structured_binary_fn, div_floor_stub) +DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub) +DECLARE_DISPATCH(structured_binary_fn, atan2_stub) +DECLARE_DISPATCH(structured_binary_fn, remainder_stub) +DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub) +DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub) +DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub) +DECLARE_DISPATCH(structured_binary_fn, lshift_stub) +DECLARE_DISPATCH(structured_binary_fn, rshift_stub) +DECLARE_DISPATCH(binary_fn, logical_xor_stub) +DECLARE_DISPATCH(binary_fn, logical_and_stub) +DECLARE_DISPATCH(binary_fn, logical_or_stub) +DECLARE_DISPATCH(structured_binary_fn, lt_stub) +DECLARE_DISPATCH(structured_binary_fn, le_stub) +DECLARE_DISPATCH(structured_binary_fn, gt_stub) +DECLARE_DISPATCH(structured_binary_fn, ge_stub) +DECLARE_DISPATCH(structured_binary_fn, eq_stub) +DECLARE_DISPATCH(structured_binary_fn, ne_stub) +DECLARE_DISPATCH(binary_fn, max_elementwise_stub) +DECLARE_DISPATCH(binary_fn, min_elementwise_stub) +DECLARE_DISPATCH(structured_binary_fn, maximum_stub) +DECLARE_DISPATCH(structured_binary_fn, minimum_stub) +DECLARE_DISPATCH(structured_binary_fn, fmax_stub) +DECLARE_DISPATCH(structured_binary_fn, fmin_stub) +DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub) +DECLARE_DISPATCH(binary_fn_double, huber_stub) +DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub) +DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub) +DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub) +DECLARE_DISPATCH(structured_binary_fn, mse_stub) +DECLARE_DISPATCH(structured_binary_fn, fmod_stub) +DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub) +DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub) +DECLARE_DISPATCH(structured_binary_fn, gcd_stub) +DECLARE_DISPATCH(structured_binary_fn, lcm_stub) +DECLARE_DISPATCH(structured_binary_fn, hypot_stub) +DECLARE_DISPATCH(structured_binary_fn, igamma_stub) +DECLARE_DISPATCH(structured_binary_fn, igammac_stub) +DECLARE_DISPATCH(structured_binary_fn, nextafter_stub) +DECLARE_DISPATCH(structured_binary_fn, heaviside_stub) +DECLARE_DISPATCH(structured_binary_fn, copysign_stub) +DECLARE_DISPATCH(structured_binary_fn, xlogy_stub) +DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub) +DECLARE_DISPATCH(structured_binary_fn, zeta_stub) +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub) +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub) +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub) +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub) 
+DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub) +DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub) +DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub) +DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub) +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub) +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub) +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub) +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/BucketizationUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/BucketizationUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..e39476738697b363f6db3213371cc83616eab6ce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/BucketizationUtils.h @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { + +// original values given by raw_*. If an original value is not contiguous, will make a contiguous copy to +// the corresponding trimmed_* value. Additionally, if the dtypes of the boundary and input tensor do not +// match, will change them to be a common super type so comparisons are done between the same types. +// For any trimmed_* tensor, if its outgoing value matches what it was incoming (typically null), then the +// corresponding raw_* version should be used since it was already contiguous of the right type. +inline void searchsorted_maybe_trim_input_tensors( + Tensor& trimmed_input, + Tensor& trimmed_boundaries, + Tensor& trimmed_sorter, + const Tensor& raw_input, + const Tensor& raw_boundaries, + const Tensor& raw_sorter) { + bool in_is_contiguous = raw_input.is_contiguous(); + bool bd_is_contiguous = raw_boundaries.is_contiguous(); + bool sort_is_contiguous = raw_sorter.is_contiguous(); + + if (!in_is_contiguous) { + TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due " + "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value " + "tensor if possible. This message will only appear once per program."); + trimmed_input = raw_input.contiguous(); + } + if (!bd_is_contiguous) { + TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due " + "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary " + "tensor if possible. This message will only appear once per program."); + trimmed_boundaries = raw_boundaries.contiguous(); + } + if (!sort_is_contiguous) { + TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due " + "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter " + "tensor if possible. 
This message will only appear once per program."); + trimmed_sorter = raw_sorter.contiguous(); + } + if (raw_input.dtype() != raw_boundaries.dtype()) { + at::native::ResultTypeState state = {}; + state = at::native::update_result_type_state(raw_boundaries, state); + state = at::native::update_result_type_state(raw_input, state); + ScalarType common_stype = at::native::result_type(state); + + TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined); + if (common_stype != raw_input.scalar_type()) { + trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype); + } + if (common_stype != raw_boundaries.scalar_type()) { + trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype); + } + } +} + +/* unused but needed for internal jagged tensor class */ +inline void searchsorted_maybe_trim_input_tensors( + Tensor& trimmed_input, + Tensor& trimmed_boundaries, + const Tensor& raw_input, + const Tensor& raw_boundaries) { + Tensor trimmed_sorter; + Tensor raw_sorter; + return searchsorted_maybe_trim_input_tensors( + trimmed_input, + trimmed_boundaries, + trimmed_sorter, + raw_input, + raw_boundaries, + raw_sorter); +} + +inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) { + if (boundaries.dim() != input.dim()) { + return false; + } + const auto& dims_bd = boundaries.sizes(); + const auto& dims_in = input.sizes(); + for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) { + if (dims_bd[dim] != dims_in[dim]) { + return false; + } + } + return true; +} + +inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) { + auto tensor = c10::scalar_to_tensor(scalar, device); + // This is to adopt the scalar promotion rules defined in native/TypeProperties.h + // So we have the same type promotion rules as binary operations. 
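The helpers above promote mismatched boundary/input dtypes to a common type before comparing, and searchsorted_pre_check below enforces the shape contract (boundaries either 1-D or matching the input in every dimension but the last). A short usage sketch against the public at::searchsorted API (requires libtorch); the expected indices assume the default right=False, i.e. leftmost-insertion, behaviour:

#include <ATen/ATen.h>
#include <cassert>

int main() {
  // float32 boundaries, int64 values: the dtypes differ, so they are promoted
  // to a common type before the comparisons run.
  auto boundaries = at::arange(0, 10, 2, at::kFloat);  // [0, 2, 4, 6, 8]
  auto values = at::arange(5, at::kLong);              // [0, 1, 2, 3, 4]

  auto idx = at::searchsorted(boundaries, values);     // int64 output by default
  assert(idx.scalar_type() == at::kLong);
  // Leftmost insertion positions: 0->0, 1->1, 2->1, 3->2, 4->2
  assert(idx[2].item<int64_t>() == 1);
  assert(idx[4].item<int64_t>() == 2);
  return 0;
}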
+ tensor.unsafeGetTensorImpl()->set_wrapped_number(true); + return tensor; +} + +inline void searchsorted_pre_check( + const Tensor& boundaries, + const Tensor& input, + const Tensor& output, + const bool out_int32, + const bool right, + const std::optional side_opt, + const Tensor& sorter) { + if (side_opt) { + const std::string_view side = *side_opt; + TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ", + "got ", side); + + // assume the user has not explicitly set (right=False, side="right") + TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side " + "of ", side, " while right was True"); + } + + TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ", + "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ", + "tensor device type ", input.device()); + + if (sorter.defined()) { + TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ", + "have same device type, but got sorter tensor device type ", sorter.device(), " and input value tensor ", + "device type ", boundaries.device()); + + TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same " + "size, but got boundary tensor ", boundaries.sizes(), "and got sorter tensor ", sorter.sizes()); + + TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ", + "dtype but got dtype ", sorter.scalar_type()); + + if (sorter.numel() > 0) { + auto minmax = sorter.aminmax(); + int64_t vmin = std::get<0>(minmax).item().toLong(); + int64_t vmax = std::get<1>(minmax).item().toLong(); + TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range"); + } + } + + TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1), + "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ", + "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(", + input.numel(), ")"); + + TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ", + "got 0 dimension"); + + TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input), + "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ", + "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ", + input.sizes()); + + ScalarType output_dtype = output.scalar_type(); + TORCH_CHECK( + (output_dtype == ScalarType::Long && !out_int32) || + (output_dtype == ScalarType::Int && out_int32), + "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ", + "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype, + " and out_int32 flag is ", (out_int32 ? 
"True" : "False")); + + if (out_int32) { + TORCH_CHECK(boundaries.sizes().back() < INT_MAX, + "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ", + boundaries.sizes().back()); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/CPUBlas.h b/phivenv/Lib/site-packages/torch/include/ATen/native/CPUBlas.h new file mode 100644 index 0000000000000000000000000000000000000000..64bbefea8db2f5d76fb93ed515b5a1025cd63910 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/CPUBlas.h @@ -0,0 +1,304 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace at::native::cpublas { + +namespace internal { +void normalize_last_dims( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + int64_t *lda, int64_t *ldb, int64_t *ldc); +} // namespace internal + +using gemm_fn = void(*)( + at::ScalarType type, + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const Scalar& alpha, + const void *a, int64_t lda, + const void *b, int64_t ldb, + const Scalar& beta, + void *c, int64_t ldc); + +DECLARE_DISPATCH(gemm_fn, gemm_stub) + +using gemm_no_downcast_fn = void(*)( + at::ScalarType type, + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const Scalar& alpha, + const void *a, int64_t lda, + const void *b, int64_t ldb, + const Scalar& beta, + void *c, int64_t ldc); + +DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub) + +template +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + at::opmath_type alpha, + const scalar_t *a, int64_t lda, + const scalar_t *b, int64_t ldb, + at::opmath_type beta, + scalar_t *c, int64_t ldc) { + internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); + gemm_stub( + kCPU, c10::CppTypeToScalarType::value, + transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + double alpha, + const double *a, int64_t lda, + const double *b, int64_t ldb, + double beta, + double *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const float *a, int64_t lda, + const float *b, int64_t ldb, + float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + float beta, + at::BFloat16 *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + const float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::Half *a, int64_t lda, + const at::Half *b, int64_t ldb, + float beta, + at::Half *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const float alpha, + const at::Half *a, int64_t lda, + const at::Half *b, int64_t ldb, + const float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + c10::complex alpha, + const c10::complex *a, int64_t lda, + const c10::complex 
*b, int64_t ldb, + c10::complex beta, + c10::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + c10::complex alpha, + const c10::complex *a, int64_t lda, + const c10::complex *b, int64_t ldb, + c10::complex beta, + c10::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + int64_t alpha, + const int64_t *a, int64_t lda, + const int64_t *b, int64_t ldb, + int64_t beta, + int64_t *c, int64_t ldc); + +template +void gemm_batched( + TransposeType transa, TransposeType transb, + int64_t batch_size, int64_t m, int64_t n, int64_t k, + scalar_t alpha, + const scalar_t * const *a, int64_t lda, + const scalar_t * const *b, int64_t ldb, + const scalar_t beta, + scalar_t * const *c, int64_t ldc); + +template +void gemm_batched_with_stride( + TransposeType transa, TransposeType transb, + int64_t batch_size, int64_t m, int64_t n, int64_t k, + scalar_t alpha, + const scalar_t *a, int64_t lda, int64_t batch_stride_a, + const scalar_t *b, int64_t ldb, int64_t batch_stride_b, + scalar_t beta, + scalar_t *c, int64_t ldc, int64_t batch_stride_c); + +using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy); + +DECLARE_DISPATCH(axpy_fn, axpy_stub) + +template +void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){ + if(n == 1) + { + incx = 1; + incy = 1; + } + axpy_stub( + kCPU, c10::CppTypeToScalarType::value, + n, a, x, incx, y, incy); +} + +void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy); +void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy); +void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); +void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); + +using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy); + +DECLARE_DISPATCH(copy_fn, copy_stub) + +template +void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) { + if(n == 1) + { + incx = 1; + incy = 1; + } + copy_stub( + kCPU, c10::CppTypeToScalarType::value, + n, x, incx, y, incy); +} + +void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy); +void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy); +void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); +void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); + +// Batch-reduce GEMM +// Operates by the following formula: +// C = SUM(A[i] x B[i]) + C if add_C is true, i = 0 to batch size +// A Base pointer to a tensor A. +// B Base pointer to a tensor B. +// C Pointer to a tensor C (accumulation buffer). 
+// Note only batch size 1 is used currently +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const at::Half* A, + const at::Half* B, + float* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const at::BFloat16* A, + const at::BFloat16* B, + float* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const float* A, + const float* B, + float* C, + bool is_vnni = false); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const unsigned char* A, + const unsigned char* B, + int32_t* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const unsigned char* A, + const signed char* B, + int32_t* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const signed char* A, + const signed char* B, + int32_t* C, + bool is_vnni = true); + +// Release brgemm hardware context +TORCH_API void brgemm_release(bool is_vnni = true); + +// Pack B matrix to get better performance if needed +TORCH_API void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out); + +// Whether pack is supported in the platform. +TORCH_API bool could_pack(ScalarType dt_in); + +} // namespace at::native::cpublas diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/CPUFallback.h b/phivenv/Lib/site-packages/torch/include/ATen/native/CPUFallback.h new file mode 100644 index 0000000000000000000000000000000000000000..5bffefe5a4223ac7756914457a525d9cc67c7ba5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/CPUFallback.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// This function implements a boxed fallback to CPU. +// External backends can add their own custom logging on top if it to customize their own CPU fallbacks. +TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false, + c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU); + +// This is a helper function that backends can use to directly call their boxed CPU fallback +// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands. +template +struct _call_fallback_fn final {}; + +template +struct _call_fallback_fn final { + static ReturnType call(typename c10::maybe_keep_symint::type... args) { + auto op = c10::Dispatcher::singleton() + // TODO: figure out how to make compiler happy without dynamic casts + .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name) + //.findSchemaOrThrow("a", "b") + .typed::type...)>(); + return c10::impl::BoxedKernelWrapper::type...)>::call( + c10::BoxedKernel::makeFromFunction(), + op, + c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset. + // TODO: get std::forward<> to work + args... 
+ ); + } +}; + +template +using call_fallback_fn_symint = _call_fallback_fn; + +template +using call_fallback_fn = _call_fallback_fn; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h b/phivenv/Lib/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h new file mode 100644 index 0000000000000000000000000000000000000000..983ff7fe26e332a979ece32d42889081e6c56fcf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits::max()); + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ComplexHelper.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ComplexHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..c75664d5df43866c84bb2499e34ee838a32e3a61 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ComplexHelper.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include + +#include +#endif + +// WARNING: this header contains non-inline functions and should be only +// included from ONE cpp file + +namespace at::native { + +// View tensor with new dtype, storage offset, sizes and strides +inline Tensor view_tensor( + const Tensor &tensor, ScalarType dtype, + c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) { + Storage storage = tensor.storage(); + auto key_set = tensor.key_set().remove(DispatchKey::Conjugate); + auto new_tensor = detail::make_tensor( + c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype)); + auto * impl = new_tensor.unsafeGetTensorImpl(); + impl->set_sizes_and_strides(sizes, strides, offset); + return new_tensor; +} + +inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) { + SymDimVector res(oldstride.size() + 1); + for (const auto i : c10::irange(oldstride.size())) { + res[i] = oldstride[i] * 2; + } + res.back() = 1; + return res; +} + +inline Tensor _view_as_real_physical(const Tensor& self) { + TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); + auto old_sizes = self.sym_sizes(); + SymDimVector new_sizes(old_sizes.size() + 1); + std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin()); + // last dimension will always have two elements containing the real and imag vals + new_sizes.back() = 2; + auto new_strides = computeStrideForViewAsReal(self.sym_strides()); + auto new_storage_offset = self.sym_storage_offset() * 2; + const auto float_type = c10::toRealValueType(self.scalar_type()); + auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides); + return real_tensor; +} + +// expects as input a complex tensor and returns back a tensor +// with corresponding real dtype containing the complex values +// in the last two dimensions +Tensor view_as_real(const Tensor& self) { + TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. 
To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original."); + return _view_as_real_physical(self); +} + +inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) { + const auto dim = oldstride.size(); + TORCH_CHECK(dim > 0 && oldstride[dim - 1] == 1, "Tensor must have a last dimension with stride 1"); + + SymDimVector res(dim - 1); + for (const auto i : c10::irange(res.size())) { + TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension"); + res[i] = oldstride[i] / 2; + } + return res; +} + +// expects as input a float or double tensor with last dimension of size 2 +// and returns back a tensor with corresponding complex dtype +Tensor view_as_complex(const Tensor& self) { + TORCH_CHECK( + self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf, + "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type()); + + auto old_sizes = self.sym_sizes(); + TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions"); + TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2"); + SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); + + const auto new_strides = computeStrideForViewAsComplex(self.sym_strides()); + const auto complex_type = c10::toComplexType(self.scalar_type()); + + TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); + const auto new_storage_offset = self.sym_storage_offset() / 2; + + return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h b/phivenv/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..27aefd57376f4468da7f628cec608c4e4837c4b5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +namespace at::native { + +struct TupleInfoCPU { + template + using tuple = std::tuple; + + template + static constexpr auto tie(Types&... args) noexcept { + return std::tie(args...); + } +}; + +template +using CompositeRandomAccessorCPU = + CompositeRandomAccessor; + +template +void swap( + references_holder rh1, + references_holder rh2 +) { + return std::swap(rh1.data(), rh2.data()); +} + +template +auto get(references_holder rh) -> decltype(std::get(rh.data())) { + return std::get(rh.data()); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h b/phivenv/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..5db76a15575c4542004ef6fbedbdc3fb73d1f3fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h @@ -0,0 +1,263 @@ +#include + +#pragma once + +namespace at::native { + +namespace { + +// operator_brackets_proxy is used in +// CompositeRandomAccessor in place of operator[]. +// For some iterators, references returned by operator[] +// could become invalid, operator_brackets_proxy tries to +// resolve that by making accessor[n] to be equivalent to +// *(accessor + n). 
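Per computeStrideForViewAsReal/computeStrideForViewAsComplex above, view_as_real appends a trailing dimension of size 2 and doubles every stride (and the storage offset), while view_as_complex does the inverse: it requires a contiguous last dimension of size 2, drops it, and halves the remaining strides and the offset. A usage sketch against the public ATen API (requires libtorch) checking that stride arithmetic for a float tensor of shape [3, 2]:

#include <ATen/ATen.h>
#include <cassert>

int main() {
  auto real = at::zeros({3, 2});              // float32, shape [3, 2], strides (2, 1)
  auto cplx = at::view_as_complex(real);      // complex64, shape [3]
  assert(cplx.dim() == 1 && cplx.size(0) == 3);
  assert(cplx.stride(0) == 1);                // strides halved, last dim dropped

  auto back = at::view_as_real(cplx);         // shape [3, 2], strides (2, 1) again
  assert(back.size(0) == 3 && back.size(1) == 2);
  assert(back.stride(0) == 2 && back.stride(1) == 1);

  // All three tensors alias one storage: writing through one view is visible
  // in the others.
  back[0][0].fill_(1.0f);
  assert(at::real(cplx[0]).item<float>() == 1.0f);
  return 0;
}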
+template +class operator_brackets_proxy { + using reference = typename std::iterator_traits::reference; + using value_type = typename std::iterator_traits::value_type; + +public: + C10_HOST_DEVICE + operator_brackets_proxy(Accessor const& accessor) + : accessor(accessor) + {} + + C10_HOST_DEVICE + operator reference() { + return *accessor; + } + + C10_HOST_DEVICE + reference operator*() { + return *accessor; + } + + C10_HOST_DEVICE + operator_brackets_proxy& operator=(value_type const& val) { + *accessor = val; + return *this; + } + +private: + Accessor accessor; +}; + +} + +// references_holder is used as a surrogate for the +// references type from std::iterator_traits in CompositeRandomAccessor. +// It is assumed in CompositeRandomAccessor that +// References = tuple, +// Values = tuple by default, +// but they could be anything as long as References could be +// cast to Values. +// If you plan to use it with STL, for example, you will need to +// define 'swap` and `get`(aka std::get) methods. +template +class references_holder { +public: + using values = Values; + using references = References; + + C10_HOST_DEVICE + references_holder(references refs) + : refs{std::move(refs)} + {} + + C10_HOST_DEVICE + operator references() { + return refs; + } + + C10_HOST_DEVICE + operator values() { + return refs; + } + + C10_HOST_DEVICE + references_holder& operator=(values vals) { + refs = vals; + return *this; + } + + C10_HOST_DEVICE + references& data() { + return refs; + } + +protected: + references refs; +}; + +// CompositeRandomAccessor is essentially a simplified version of +// a random access iterator over two random access iterators. +// TupleInfo should contain a variadic type `tuple`, and a method `tie`, +// which constructs a tuple of references from a variadic list of arguments. +template +class CompositeRandomAccessor { + using self_type = CompositeRandomAccessor; + + using key_accessor_value_type = + typename std::iterator_traits::value_type; + using value_accessor_value_type = + typename std::iterator_traits::value_type; + using key_accessor_reference_type = + typename std::iterator_traits::reference; + using value_accessor_reference_type = + typename std::iterator_traits::reference; + + using composite_value_type = typename TupleInfo::template tuple< + key_accessor_value_type, + value_accessor_value_type>; + using composite_reference = typename TupleInfo::template tuple< + key_accessor_reference_type, + value_accessor_reference_type>; + +public: + using value_type = composite_value_type; + using reference = references_holder; + // Note that CompositeRandomAccessor does not hold key and values + // in a specific datastructure, which means that a pointer to a (key, value) + // is not defined. Hence we just use a pointer type of the KeyAccessor. + using pointer = typename std::iterator_traits::pointer; + using difference_type = typename std::iterator_traits::difference_type; + using iterator_category = std::random_access_iterator_tag; + + C10_HOST_DEVICE + CompositeRandomAccessor() = default; + + C10_HOST_DEVICE + CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values) + : keys(keys), values(values) + {} + + // Pointer-like operations { + C10_HOST_DEVICE + reference operator*() const { + return TupleInfo::tie(*keys, *values); + } + + // operator->() is supposed to return a pointer type. + // Since CompositeRandomAccessor does not hold pointers to pairs, + // we just return a pointer to a key. 
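+  // In other words (illustrative sketch, assumed names): operator->() only
+  // reaches the current key, while *it yields a references_holder over
+  // (key, value) whose elements write through to the underlying sequences
+  // via the free get<>() helper defined in the CPU header, e.g.
+  //
+  //   int   keys[3] = {3, 1, 2};
+  //   float vals[3] = {30.f, 10.f, 20.f};
+  //   CompositeRandomAccessor<int*, float*, TupleInfoCPU> it(keys, vals);
+  //   get<0>(*it) = 7;              // writes keys[0]
+  //   float v = get<1>(*(it + 2));  // reads vals[2]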
+ C10_HOST_DEVICE + auto* operator->() const { + return keys.operator->(); + } + + C10_HOST_DEVICE + reference operator[](difference_type idx) { + return operator_brackets_proxy( + CompositeRandomAccessor(keys + idx, values + idx) + ); + } + // } + + // Prefix/postfix increment/decrement { + C10_HOST_DEVICE + CompositeRandomAccessor& operator++() { + ++keys; + ++values; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator++(int) { + CompositeRandomAccessor copy(*this); + ++*this; + return copy; + } + + C10_HOST_DEVICE + CompositeRandomAccessor& operator--() { + --keys; + --values; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator--(int) { + CompositeRandomAccessor copy(*this); + --*this; + return copy; + } + // } + + // Arithmetic operations { + C10_HOST_DEVICE + CompositeRandomAccessor& operator+=(difference_type offset) { + keys += offset; + values += offset; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator+(difference_type offset) const { + return CompositeRandomAccessor(keys + offset, values + offset); + } + + C10_HOST_DEVICE + friend CompositeRandomAccessor operator+( + difference_type offset, + const CompositeRandomAccessor& accessor + ) { + return accessor + offset; + } + + C10_HOST_DEVICE + CompositeRandomAccessor& operator-=(difference_type offset) { + keys -= offset; + values -= offset; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator-(difference_type offset) const { + return CompositeRandomAccessor(keys - offset, values - offset); + } + + C10_HOST_DEVICE + difference_type operator-(const CompositeRandomAccessor& other) const { + return keys - other.keys; + } + // } + + // Comparison operators { + C10_HOST_DEVICE + bool operator==(const CompositeRandomAccessor& other) const { + return keys == other.keys; + } + + C10_HOST_DEVICE + bool operator!=(const CompositeRandomAccessor& other) const { + return keys != other.keys; + } + + C10_HOST_DEVICE + bool operator<(const CompositeRandomAccessor& other) const { + return keys < other.keys; + } + + C10_HOST_DEVICE + bool operator<=(const CompositeRandomAccessor& other) const { + return keys <= other.keys; + } + + C10_HOST_DEVICE + bool operator>(const CompositeRandomAccessor& other) const { + return keys > other.keys; + } + + C10_HOST_DEVICE + bool operator>=(const CompositeRandomAccessor& other) const { + return keys >= other.keys; + } + // } + +protected: + KeyAccessor keys; + ValueAccessor values; +}; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ConvUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ConvUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..12ac0f531392650b29910a822a3ed10edf032161 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ConvUtils.h @@ -0,0 +1,455 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +using conv_depthwise2d_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, std::array); +DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub) +using conv_depthwise3d_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, std::array); +DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub) +using 
cudnn_convolution_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool, bool, std::array); +DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub) +using mps_convolution_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, std::array); +DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub) +using cudnn_convolution_transpose_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array); +DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub) +using miopen_convolution_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool, std::array); +DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub) +using miopen_convolution_transpose_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array); +DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub) +using miopen_depthwise_convolution_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool, std::array); +DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub) +using mkldnn_convolution_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, std::array); +DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub) +using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional&, + IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t); +DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub) +using mkldnn_convolution_transpose_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, std::array); +DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub) +using slow_conv_dilated2d_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, std::array); +DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub) +using slow_conv_dilated3d_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, std::array); +DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub) +using slow_conv_transpose2d_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array); +DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub) +using 
slow_conv_transpose3d_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array); +DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub) + +namespace { + bool is_cudnnv8_heuristic_mode_b() { + static const bool is_cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true; + return is_cudnnv8_heuristic_mode_b; + } +} + +inline bool cudnnv8_enabled_check_debug() { + static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true; + static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true; + static uint8_t cudnnv8_debugcount = 0; + if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) { + TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", is_cudnnv8_heuristic_mode_b()); + cudnnv8_debugcount++; + } + return cudnnv8_flag == 1; +} + +inline bool cudnnv8_use_heur_mode_b() { + return is_cudnnv8_heuristic_mode_b(); +} + +// Keep in sync with py::enum_ in Module.cpp +enum class ConvBackend { + CudaDepthwise2d, + CudaDepthwise3d, + Cudnn, + CudnnTranspose, + Empty, + Miopen, + MiopenDepthwise, + MiopenTranspose, + Mkldnn, + MkldnnTranspose, + MkldnnEmpty, + NnpackSpatial, + Overrideable, + Slow2d, + Slow3d, + SlowDilated2d, + SlowDilated3d, + SlowTranspose2d, + SlowTranspose3d, + Winograd3x3Depthwise, + Xnnpack2d, + Mps, + MpsTranspose, +}; + +// Overload for selecting the convolution backend from the full set of convolution inputs. +// This overload is exposed to python for testing, etc. +TORCH_API ConvBackend select_conv_backend( + const Tensor& input, const Tensor& weight, const std::optional& bias_opt, + SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, + bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt); + +TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input, + const Tensor& weight, + const ConvBackend backend); + +// --------------------------------------------------------------------- +// +// Math +// +// --------------------------------------------------------------------- + +constexpr int input_batch_size_dim = 0; // also grad_input +constexpr int input_channels_dim = 1; +constexpr int output_batch_size_dim = 0; // also grad_output +constexpr int output_channels_dim = 1; +constexpr int weight_output_channels_dim = 0; +constexpr int weight_input_channels_dim = 1; + +// Often written as 2 + max_dim (extra dims for batch size and channels) +constexpr int max_dim = 3; + +// --------------------------------------------------------------------- +// +// Checking +// +// --------------------------------------------------------------------- + +// Used on pad, stride and dilation +static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name) +{ + TORCH_CHECK(args.size() <= expected_size, + "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ", + expected_size, " (while checking arguments for ", c, ")"); + TORCH_CHECK(args.size() >= expected_size, + "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ", + expected_size, " (while checking arguments for ", c, ")"); + + auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;}); + if (num_negative_values > 0){ + std::stringstream ss; + 
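+    // e.g. this builds: "padding should be greater than zero but got (-1, 1, 1) (while checking arguments for conv2d)"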
ss << arg_name << " should be greater than zero but got ("; + std::copy(args.begin(), args.end() - 1, std::ostream_iterator(ss,", ")); + ss << args.back() << ")" << " (while checking arguments for " << c << ")"; + TORCH_CHECK(false, ss.str()); + } +} + + +// NOTE [ Convolution checks ] +// +// NB: For many call sites, it is not strictly necessary to check all of +// these relationships (for example, for forward convolution, we compute +// the size of output ourselves, so we don't actually need to check +// output. However, writing a single function that does everything +// means we get to reuse it for both forwards and all backwards +// variants, even when the set of "real" inputs varies. The magic of +// relational computing! +// +// (There is one downside, which is that it is slightly harder to write +// error messages which are able to distinguish between real inputs +// (which the user can change) and computed inputs (which the user can +// only indirectly affect). It would be an interesting exercise to +// come up with a general framework to handle such situations.) +inline void convolution_shape_check( + CheckedFrom c, + const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) +{ + check_args(c, padding, input->dim() - 2, "padding"); + check_args(c, stride, padding.size(), "stride"); + check_args(c, dilation, padding.size(), "dilation"); + + // Input + checkDimRange(c, input, 3, 6 /* exclusive */); + checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups); + + // Weight + checkSameDim(c, input, weight); + + // TODO: check that output->size() matches output_sizes + // TODO: check that weight matches output->sizes() + checkSameDim(c, input, output); +} + +// NB: conv_output_size and conv_input_size are not bijections, +// as conv_output_size loses information; this is why conv_input_size +// takes an extra output_padding argument to resolve the ambiguity. + +template +inline std::vector _conv_output_size( + ArrayRef input_size, ArrayRef weight_size, + ArrayRef padding, ArrayRef stride, ArrayRef dilation = ArrayRef() +) { + // ASSERT(input_size.size() > 2) + // ASSERT(input_size.size() == weight_size.size()) + bool has_dilation = !dilation.empty(); + auto dim = input_size.size(); + std::vector output_size(dim); + output_size[0] = input_size[input_batch_size_dim]; + output_size[1] = weight_size[weight_output_channels_dim]; + for (const auto d : c10::irange(2, dim)) { + auto dilation_ = has_dilation ? 
dilation[d - 2] : 1; + auto kernel = dilation_ * (weight_size[d] - 1) + 1; + output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1; + } + return output_size; +} + +inline std::vector conv_output_size( + IntArrayRef input_size, IntArrayRef weight_size, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +inline std::vector conv_output_size( + SymIntArrayRef input_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +template +std::vector _conv_input_size( + ArrayRef output_size, ArrayRef weight_size, + ArrayRef padding, ArrayRef output_padding, ArrayRef stride, ArrayRef dilation, T groups +) { + // ASSERT(output_size.size() > 2) + // ASSERT(output_size.size() == weight_size.size()) + auto dim = output_size.size(); + std::vector input_size(dim); + input_size[0] = output_size[output_batch_size_dim]; + input_size[1] = weight_size[weight_input_channels_dim] * groups; + for (const auto d : c10::irange(2, dim)) { + auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1; + input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) + + kernel + output_padding[d - 2]; + } + return input_size; +} + +inline std::vector conv_input_size( + SymIntArrayRef output_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups)); +} + +inline std::vector conv_input_size( + IntArrayRef output_size, IntArrayRef weight_size, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); +} + +template +std::vector _conv_weight_size( + ArrayRef input_size, ArrayRef output_size, + ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + auto dim = input_size.size(); + std::vector weight_size(dim); + weight_size[0] = output_size[1]; + weight_size[1] = input_size[1] / groups; + for (const auto d : c10::irange(2, dim)) { + auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] + + padding[d - 2] * 2 - output_padding[d - 2]; + weight_size[d] = (kernel - 1) / dilation[d - 2] + 1; + } + return weight_size; +} + +inline std::vector conv_weight_size( + SymIntArrayRef input_size, SymIntArrayRef output_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + +inline std::vector conv_weight_size( + IntArrayRef input_size, IntArrayRef output_size, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + +inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { + std::vector shape(dim, 1); + shape[1] = -1; + return bias.reshape(shape); +} + +inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& 
weight) { + // disable NHWC for float64 input. + if (!at::detail::getCUDAHooks().compiledWithCuDNN() || + input.scalar_type() == at::kDouble || + weight.scalar_type() == at::kDouble) { + return at::MemoryFormat::Contiguous; + } + long cudnn_version = at::detail::getCUDAHooks().versionCuDNN(); + auto input_memory_format = input.suggest_memory_format(); + auto weight_memory_format = weight.suggest_memory_format(); + auto weight_ndim = weight.ndimension(); + + bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && ( + (input_memory_format == at::MemoryFormat::ChannelsLast) || + (weight_memory_format == at::MemoryFormat::ChannelsLast) + ); + if (can_use_cudnn_channels_last_2d) { + return at::MemoryFormat::ChannelsLast; + } + + bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && ( + (input_memory_format == at::MemoryFormat::ChannelsLast3d) || + (weight_memory_format == at::MemoryFormat::ChannelsLast3d) + ); + if (can_use_cudnn_channels_last_3d) { + return at::MemoryFormat::ChannelsLast3d; + } + + return at::MemoryFormat::Contiguous; +} + +// controls whether emptyCache will be called following cudnn conv benchmarking +TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable); +TORCH_API bool _cudnn_get_conv_benchmark_empty_cache(); + + +inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { + + // disable NHWC for float64 input. + if (!at::detail::getCUDAHooks().compiledWithMIOpen() || + input.scalar_type() == at::kDouble || + weight.scalar_type() == at::kDouble) { + return false; + } + + bool can_use_miopen_channels_last_2d = false; + // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen + // See #64427 + static std::optional PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC"); + + auto input_memory_format = input.suggest_memory_format(); + auto weight_memory_format = weight.suggest_memory_format(); + + can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && ( + ( (input_memory_format == at::MemoryFormat::ChannelsLast) || + (weight_memory_format == at::MemoryFormat::ChannelsLast) ) + ); + + bool can_use_miopen_channels_last_3d = false; + + return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d; +} + +inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { + + // disable NHWC for float64 input. + if (input.scalar_type() == at::kDouble || + weight.scalar_type() == at::kDouble) { + return false; + } + + // disable NHWC for MkldnnCPU tensor. 
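+  // (MkldnnCPU tensors use oneDNN's own opaque/blocked layouts, so a
+  //  channels-last memory-format suggestion is not meaningful for them.)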
+ if (input.is_mkldnn() || weight.is_mkldnn()) { + return false; + } + + auto input_memory_format = input.suggest_memory_format(); + auto weight_memory_format = weight.suggest_memory_format(); + + bool can_use_mkldnn_channels_last_2d = + (input_memory_format == at::MemoryFormat::ChannelsLast) || + (weight_memory_format == at::MemoryFormat::ChannelsLast); + + bool can_use_mkldnn_channels_last_3d = + (input_memory_format == at::MemoryFormat::ChannelsLast3d) || + (weight_memory_format == at::MemoryFormat::ChannelsLast3d); + + return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d; +} + +inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { + + auto input_memory_format = input.suggest_memory_format(); + auto weight_memory_format = weight.suggest_memory_format(); + + bool can_use_thnn_channels_last_2d = input.device().is_cpu() && ( + (input_memory_format == at::MemoryFormat::ChannelsLast) || ( + weight_memory_format == at::MemoryFormat::ChannelsLast)); + + return can_use_thnn_channels_last_2d; +} + +inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { + + // check layout only for xpu tensor. + if (!input.is_xpu() || !weight.is_xpu()) { + return false; + } + if (!input.defined() || input.is_sparse()) { + // suggest channels_first + return false; + } + + auto is_channel_last = [](const at::Tensor& t) { + auto fmt = t.suggest_memory_format(); + return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d; + }; + return is_channel_last(input) || is_channel_last(weight); +} + +inline bool mps_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { + + // check layout only for mps tensor. + if (!input.is_mps() || !weight.is_mps()) { + return false; + } + if (!input.defined() || input.is_sparse()) { + // suggest channels_first + return false; + } + + auto fmt = input.suggest_memory_format(); + return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ConvolutionMM3d.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ConvolutionMM3d.h new file mode 100644 index 0000000000000000000000000000000000000000..6db6f69d96a67c04ef0e689d88c1cf40392d9e18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ConvolutionMM3d.h @@ -0,0 +1,14 @@ +#include + +namespace at::native { + +std::tuple slow_conv3d_backward_cpu( + const Tensor& grad_output, + const Tensor& self, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + std::array output_mask); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Copy.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..2aeae0097c817cd8b642cf8da67b124b6e19a6f3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Copy.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at { + +class Tensor; +struct TensorIterator; +class TensorBase; + +namespace native { + +using copy_fn = void (*)(TensorIterator&, bool non_blocking); + +DECLARE_DISPATCH(copy_fn, copy_stub) + +TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src); + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Cross.h 
b/phivenv/Lib/site-packages/torch/include/ATen/native/Cross.h new file mode 100644 index 0000000000000000000000000000000000000000..7230219704b07fbd6a48f5faac7ab9e74f047720 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Cross.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace at { +class Tensor; + +namespace native { + +using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d); + +DECLARE_DISPATCH(cross_fn, cross_stub) + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..687700b8d2a133d4355a65495a44e2bf534ae0bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h @@ -0,0 +1,229 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \ + TORCH_CHECK( \ + T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \ + "Need " #T " of dimension ", \ + DIM, \ + " and " #T ".size[", \ + DIM_SIZE, \ + "] == ", \ + SIZE, \ + " but got input to be of shape ", \ + T.sizes()) + +namespace at::native::internal { +namespace { +inline bool all_positive(IntArrayRef& arr) { + return std::all_of( + arr.begin(), arr.end(), [](int64_t item) { return item > 0; }); +} + +inline bool all_nonnegative(std::vector& arr) { + return std::all_of( + arr.begin(), arr.end(), [](int64_t item) { return item >= 0; }); +} + +} // namespace + +// calculate the rear part of output tensor sizes +template +std::vector get_output_size( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + std::vector sizes; + for (const auto index : c10::irange(dim)) { + sizes.push_back( + div_rtn( + input.size(index + input.dim() - dim) + 2 * pad_size[index] - + (dilation_size[index] * (kernel_size[index] - 1) + 1), + stride_size[index]) + + 1); + } + return sizes; +} + +// calculate the sizes of output tensor +template +std::vector get_output_size( + const Tensor& input, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + auto output_size = get_output_size( + input, kernel_size, stride_size, pad_size, dilation_size); + output_size.insert(output_size.begin(), weight.size(0)); + if (input.dim() == dim + 2) { + output_size.insert(output_size.begin(), input.size(0)); + } + return output_size; +} +/* + slow_conv_dilated_shape_check - check user-input to dilated convolution + forward and backward functions. +*/ +template +void slow_conv_dilated_shape_check( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + /* + When the following tensors are defined: + + bias, grad_weight, grad_output + + then these are assumed to be contiguous without checking + because of these tensors are made contiguous by calling + .contiguous() method or by resizing of zero-sized tensors in + forward/backward functions. + + When grad_weight is defined then it is assumed without + checking to have the same shape as weight, see backward + functions. 
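+
+    As an illustration of the sizes being validated (made-up numbers): for one
+    spatial dimension with input size 32, kernel_size 3, pad 1, dilation 2 and
+    stride 2, get_output_size above computes
+      div_rtn(32 + 2*1 - (2*(3-1) + 1), 2) + 1 = div_rtn(29, 2) + 1 = 15,
+    and the checks below only require every such size to be non-negative.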
+ */ + // Check size arguments + TORCH_CHECK( + kernel_size.size() == dim, + "kernel sizes length should be ", + dim, + ", but got ", + kernel_size.size()); + TORCH_CHECK( + stride_size.size() == dim, + "strides length should be ", + dim, + ", but got ", + stride_size.size()); + TORCH_CHECK( + dilation_size.size() == dim, + "dilations length should be ", + dim, + ", but got ", + dilation_size.size()); + TORCH_CHECK( + pad_size.size() == dim, + "pads length should be ", + dim, + ", but got ", + pad_size.size()); + + TORCH_CHECK( + all_positive(kernel_size), + "kernel size should be greater than zero, but got ", + kernel_size); + TORCH_CHECK( + all_positive(stride_size), + "stride should be greater than zero, but got ", + stride_size); + TORCH_CHECK( + all_positive(dilation_size), + "dilation should be greater than zero, but got ", + dilation_size); + + // check input + TORCH_CHECK(input.defined(), "input must be defined"); + bool is_batch = input.dim() == dim + 2; + int64_t n = (is_batch ? 2 : 1); + int64_t ndim = n + dim; + if (!is_batch) { + // input dim has to be dim + 1 if not batched + TORCH_CHECK( + input.dim() == dim + 1, + "input must be 4D or 5D tensor but got ", + input.dim(), + "D tensor"); + } + + // check output sizes + auto output_size = get_output_size( + input, kernel_size, stride_size, pad_size, dilation_size); + + TORCH_CHECK( + all_nonnegative(output_size), + "calculated output size ", + output_size, + " is too small (all sizes must be non-negative)"); + + // check weight + TORCH_CHECK(weight.defined(), "weight must be defined"); + TORCH_CHECK( + weight.dim() == dim + 2, + "weight must be ", + dim + 2, + "D tensor but got ", + weight.dim(), + "D tensor dim=", + dim); + TORCH_CHECK( + weight.sizes().slice(2) == kernel_size, + "weight[2:] shape ", + weight.sizes().slice(2), + " must be equal to kernel_size ", + kernel_size); + + TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1)); + + // check bias when present + if (bias.defined()) { + TORCH_CHECK( + bias.dim() == 1, + "bias must be 1D tensor but got ", + bias.dim(), + "D tensor"); + TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0)); + } + + // check grad_output when present + if (grad_output.defined()) { + TORCH_CHECK( + grad_output.dim() == ndim, + "grad_output must be ", + ndim, + "D tensor but got ", + grad_output.dim(), + "D tensor"); + if (is_batch) { + TORCH_CHECK( + grad_output.size(0) == input.size(0), + "grad_output.size(0)=", + grad_output.size(0), + " must be input.size(0)=", + input.size(0)); + } + TORCH_CHECK( + grad_output.size(n - 1) == weight.size(0), + "grad_output.size(", + n - 1, + ")=", + grad_output.size(n - 1), + " must be weight.size(0)=", + weight.size(0)); + TORCH_CHECK( + grad_output.sizes().slice(n) == output_size, + "grad_output[", + n, + ":] shape", + grad_output.sizes().slice(n), + " must be equal to output size ", + output_size); + } +} + +} // namespace at::native::internal diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/DispatchStub.h b/phivenv/Lib/site-packages/torch/include/ATen/native/DispatchStub.h new file mode 100644 index 0000000000000000000000000000000000000000..83f0883d0f59ecdd118fc77205bcd2279181c89e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/DispatchStub.h @@ -0,0 +1,495 @@ +#pragma once + +#include +#include + +#include +#include +#include + +// Implements instruction set specific function dispatch. +// +// Kernels that may make use of specialized instruction sets (e.g. 
AVX2) are +// compiled multiple times with different compiler flags (e.g. -mavx2). A +// DispatchStub contains a table of function pointers for a kernel. At runtime, +// the fastest available kernel is chosen based on the features reported by +// cpuinfo. +// +// Example: +// +// In native/MyKernel.h: +// using fn_type = void(*)(const Tensor& x); +// DECLARE_DISPATCH(fn_type, stub) +// +// In native/MyKernel.cpp +// DEFINE_DISPATCH(stub); +// +// In native/cpu/MyKernel.cpp: +// namespace { +// // use anonymous namespace so that different cpu versions won't conflict +// void kernel(const Tensor& x) { ... } +// } +// REGISTER_DISPATCH(stub, &kernel); +// +// To call: +// stub(kCPU, tensor); +// +// TODO: CPU instruction set selection should be folded into whatever +// the main dispatch mechanism is. +// +// Supported device types for registration: +// - CPU: Central Processing Unit +// - CUDA: NVIDIA GPUs +// - HIP: AMD GPUs +// - MPS: Apple Silicon GPUs (Metal Performance Shaders) +// - MTIA: Meta Training and Inference Devices +// - XPU: Intel GPUs +// - HPU: Reserved for HPU (Intel Gaudi) device types +// - PrivateUse1: Reserved for private/custom device types +// +// If you want to update the list of supported devices, add a new dispatch_ptr +// member in DispatchStubImpl.h and update the get_call_ptr switch. +// As well you will need to update the inlined list in 'is_device_supported` +// +// +// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere +C10_CLANG_DIAGNOSTIC_PUSH() +C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template") + +namespace at::native { + +enum class CPUCapability { + DEFAULT = 0, +#if defined(HAVE_VSX_CPU_DEFINITION) + VSX = 1, +#elif defined(HAVE_ZVECTOR_CPU_DEFINITION) + ZVECTOR = 1, +#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION) + SVE256 = 1, +#else + AVX2 = 1, + AVX512 = 2, +#endif + NUM_OPTIONS +}; + +// Enum for error types +enum class ErrorType { + MissingDeviceKernel, + DeviceNotSupported +}; + +// Alias for the return type using std::variant +using DispatchResult = std::variant; + +CPUCapability get_cpu_capability(); + +template +struct DispatchStub; + +/** + * The sole purpose of this class is to outline methods that don't need to be + * specialized or otherwise inlined and duplicated (by the compiler due to + * template expansion), since it causes size bloat if there are a significant + * number of specialization of the DispatchStub<> class. + */ +struct TORCH_API DispatchStubImpl { + + // The DispatchStubImpl::try_get_call_ptr() method is used to get the call + // pointer for a given device type. If the call pointer is not found, + // DispatchStubImpl::try_get_call_ptr() returns an ErrorType. + // The main difference between try_get_call_ptr() and get_call_ptr() is that + // try_get_call_ptr() will return the ErrorType and not raise an exception. + DispatchResult try_get_call_ptr( + c10::DeviceType device_type + , void *DEFAULT +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + , void *AVX2 +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + , void *VSX +#endif +#ifdef HAVE_ZVECTOR_CPU_DEFINITION + , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif + ); + + // Analogous to try_get_call_ptr(), but it will return the ErrorType and not + // raise an exception. 
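+  // In either case the CPU selection boils down to a capability-ordered
+  // fallback; a minimal sketch of the idea (illustrative, not the actual
+  // implementation):
+  //
+  //   void* pick(int capability, void* DEFAULT, void* AVX2, void* AVX512) {
+  //     if (capability >= static_cast<int>(CPUCapability::AVX512) && AVX512)
+  //       return AVX512;
+  //     if (capability >= static_cast<int>(CPUCapability::AVX2) && AVX2)
+  //       return AVX2;
+  //     return DEFAULT;   // nullptr here surfaces as a missing-kernel error
+  //   }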
+ DispatchResult try_choose_cpu_impl( + void *DEFAULT +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + , void *AVX2 +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + , void *VSX +#endif +#ifdef HAVE_ZVECTOR_CPU_DEFINITION + , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif + ); + + + void* get_call_ptr( + c10::DeviceType device_type + , void *DEFAULT +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + , void *AVX2 +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + , void *VSX +#endif +#ifdef HAVE_ZVECTOR_CPU_DEFINITION + , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif + ); + + /** + * The CPU Dispatch actual method is chosen in decreasing order of preference by + * DispatchStubImpl::choose_cpu_impl() in case none is found by + * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr. + */ + void* choose_cpu_impl( + void *DEFAULT +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + , void *AVX2 +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + , void *VSX +#endif +#ifdef HAVE_ZVECTOR_CPU_DEFINITION + , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif + ); + + // Fixing dispatch error in Windows debug builds. + // See https://github.com/pytorch/pytorch/issues/22681 for more details. + #if defined(_MSC_VER) && defined(_DEBUG) + std::atomic cpu_dispatch_ptr; + void* cuda_dispatch_ptr; + void* hip_dispatch_ptr; + void* mps_dispatch_ptr; + void* mtia_dispatch_ptr; + #if defined(USE_XPU) + void* xpu_dispatch_ptr; + #endif + void* hpu_dispatch_ptr; + void* privateuse1_dispatch_ptr; + #else + std::atomic cpu_dispatch_ptr{nullptr}; + void* cuda_dispatch_ptr = nullptr; + void* hip_dispatch_ptr = nullptr; + void* mps_dispatch_ptr = nullptr; + void* mtia_dispatch_ptr = nullptr; + #if defined(USE_XPU) + void* xpu_dispatch_ptr = nullptr; + #endif + void* hpu_dispatch_ptr = nullptr; + void* privateuse1_dispatch_ptr = nullptr; + #endif +}; + +template +struct DispatchStub { + using FnPtr = rT (*) (Args...); + + DispatchStub() = default; + DispatchStub(const DispatchStub&) = delete; + DispatchStub& operator=(const DispatchStub&) = delete; + +private: + FnPtr get_call_ptr(const c10::DeviceType device_type) { + return reinterpret_cast( + impl.get_call_ptr(device_type + , reinterpret_cast(DEFAULT) +#ifdef HAVE_AVX512_CPU_DEFINITION + , reinterpret_cast(AVX512) +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + , reinterpret_cast(AVX2) +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + , reinterpret_cast(VSX) +#endif +#ifdef HAVE_ZVECTOR_CPU_DEFINITION + , reinterpret_cast(ZVECTOR) +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , reinterpret_cast(SVE256) +#endif + ) + ); + } + +public: + template + rT operator()(c10::DeviceType device_type, ArgTypes&&... 
args) { + FnPtr call_ptr = get_call_ptr(device_type); + return (*call_ptr)(std::forward(args)...); + } + + void set_cuda_dispatch_ptr(FnPtr fn_ptr) { + impl.cuda_dispatch_ptr = reinterpret_cast(fn_ptr); + } + + #if defined(USE_XPU) + void set_xpu_dispatch_ptr(FnPtr fn_ptr){ + impl.xpu_dispatch_ptr = reinterpret_cast(fn_ptr); + } + #endif + + void set_hpu_dispatch_ptr(FnPtr fn_ptr) { + impl.hpu_dispatch_ptr = reinterpret_cast(fn_ptr); + } + + void set_hip_dispatch_ptr(FnPtr fn_ptr) { + impl.hip_dispatch_ptr = reinterpret_cast(fn_ptr); + } + + void set_mps_dispatch_ptr(FnPtr fn_ptr) { + impl.mps_dispatch_ptr = reinterpret_cast(fn_ptr); + } + + void set_mtia_dispatch_ptr(FnPtr fn_ptr) { + impl.mtia_dispatch_ptr = reinterpret_cast(fn_ptr); + } + + void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) { + impl.privateuse1_dispatch_ptr = reinterpret_cast(fn_ptr); + } + + // Returns true if the dispatcher has a kernel registered for this device + // type. + bool is_device_supported(const c10::DeviceType device_type) { + auto result = impl.try_get_call_ptr(device_type + , reinterpret_cast(DEFAULT) +#ifdef HAVE_AVX512_CPU_DEFINITION + , reinterpret_cast(AVX512) +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + , reinterpret_cast(AVX2) +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + , reinterpret_cast(VSX) +#endif +#ifdef HAVE_ZVECTOR_CPU_DEFINITION + , reinterpret_cast(ZVECTOR) +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , reinterpret_cast(SVE256) +#endif + ); + if (std::holds_alternative(result)){ + return false; + } + return true; + } + + static TORCH_API FnPtr DEFAULT; +#ifdef HAVE_AVX512_CPU_DEFINITION + static TORCH_API FnPtr AVX512; +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + static TORCH_API FnPtr AVX2; +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + static TORCH_API FnPtr VSX; +#endif +#ifdef HAVE_ZVECTOR_CPU_DEFINITION + static TORCH_API FnPtr ZVECTOR; +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + static TORCH_API FnPtr SVE256; +#endif +private: + DispatchStubImpl impl; +}; + +namespace { +template +struct RegisterCUDADispatch { + RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { + stub.set_cuda_dispatch_ptr(value); + } +}; + +template +struct RegisterXPUDispatch { + RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){ + stub.set_xpu_dispatch_ptr(value); + } +}; + +template +struct RegisterHPUDispatch { + RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){ + stub.set_hpu_dispatch_ptr(value); + } +}; + +template +struct RegisterMPSDispatch { + RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { + stub.set_mps_dispatch_ptr(value); + } +}; + +template +struct RegisterHIPDispatch { + RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { + // TODO: make this point at hip_dispatch_ptr + stub.set_cuda_dispatch_ptr(value); + } +}; + +template +struct RegisterMTIADispatch { + RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { + stub.set_mtia_dispatch_ptr(value); + } +}; + +template +struct RegisterPRIVATEUSE1Dispatch { + RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { + stub.set_privateuse1_dispatch_ptr(value); + } +}; + +} // anonymous namespace +// Compiler will complain if you put things like std::tuple in +// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g., +// adding parentheses and using helper struct to get rid of the parentheses, do +// not work with MSVC. 
So do a `using`-declaration if you need to pass in such +// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h. +#define DECLARE_DISPATCH(fn, name) \ + struct name##_DECLARE_DISPATCH_type : DispatchStub { \ + name##_DECLARE_DISPATCH_type() = default; \ + name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \ + name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \ + name##_DECLARE_DISPATCH_type(name##_DECLARE_DISPATCH_type&&) = delete; \ + name##_DECLARE_DISPATCH_type& operator=(name##_DECLARE_DISPATCH_type&&) = delete; \ + ~name##_DECLARE_DISPATCH_type() = default; \ + }; \ + extern TORCH_API struct name##_DECLARE_DISPATCH_type name; + +#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name + +#define REGISTER_ARCH_DISPATCH(name, arch, fn) \ + template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub::arch = fn; + +#ifdef HAVE_AVX512_CPU_DEFINITION +#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn) +#else +#define REGISTER_AVX512_DISPATCH(name, fn) +#endif + +#ifdef HAVE_AVX2_CPU_DEFINITION +#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn) +#else +#define REGISTER_AVX2_DISPATCH(name, fn) +#endif + +#ifdef HAVE_VSX_CPU_DEFINITION +#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn) +#else +#define REGISTER_VSX_DISPATCH(name, fn) +#endif + +#ifdef HAVE_ZVECTOR_CPU_DEFINITION +#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn) +#else +#define REGISTER_ZVECTOR_DISPATCH(name, fn) +#endif + +#ifdef HAVE_SVE256_CPU_DEFINITION +#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn) +#else +#define REGISTER_SVE256_DISPATCH(name, fn) +#endif + +// Macro to register the same kernel for all CPU arch types. This is useful +// if a kernel does not benefit from being recompiled across different arch types. +#define REGISTER_ALL_CPU_DISPATCH(name, fn) \ + REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \ + REGISTER_AVX512_DISPATCH(name, fn) \ + REGISTER_AVX2_DISPATCH(name, fn) \ + REGISTER_VSX_DISPATCH(name, fn) \ + REGISTER_ZVECTOR_DISPATCH(name, fn) \ + REGISTER_SVE256_DISPATCH(name, fn) + +#define REGISTER_NO_CPU_DISPATCH(name) \ + REGISTER_ALL_CPU_DISPATCH(name, nullptr) + +#define REGISTER_CUDA_DISPATCH(name, fn) \ + static RegisterCUDADispatch name ## __register(name, fn); + +#define REGISTER_XPU_DISPATCH(name, fn) \ + static RegisterXPUDispatch name ## __register(name, fn); + +#define REGISTER_HPU_DISPATCH(name, fn) \ + static RegisterHPUDispatch name ## __register(name, fn); + +#define REGISTER_HIP_DISPATCH(name, fn) \ + static RegisterHIPDispatch name ## __register(name, fn); + +#define REGISTER_MPS_DISPATCH(name, fn) \ + static RegisterMPSDispatch name ## __register(name, fn); + +#define REGISTER_MTIA_DISPATCH(name, fn) \ + static RegisterMTIADispatch name ## __register(name, fn); + +#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \ + static RegisterPRIVATEUSE1Dispatch name ## __register(name, fn); + +// NB: This macro must be used in an actual 'cu' file; if you try using +// it from a 'cpp' file it will not work! +#if defined(__CUDACC__) +#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn) +#elif defined(__HIPCC__) +// TODO: cut this over to HIP dispatch once we stop pretending that CUDA +// is HIP in the PyTorch HIPify build. 
+#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn) +// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn) +#elif defined(__OBJC__) && defined(USE_MPS) +// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel +#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn) +#elif defined(CPU_CAPABILITY) +// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches. +// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others. +// ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others. +#ifdef CPU_CAPABILITY_AVX512 +#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr)) +#else +#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#endif +#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#endif +} // namespace at::native + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Distance.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Distance.h new file mode 100644 index 0000000000000000000000000000000000000000..e9532dcaca0f8afdfcf5a5e2c181841a1dce9169 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Distance.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at { +class Tensor; + +namespace native { + +using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p); +using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&); +using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p); +using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&); + +DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub) +DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub) +DECLARE_DISPATCH(cdist_fn, cdist_stub) +DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub) + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/DistributionTemplates.h b/phivenv/Lib/site-packages/torch/include/ATen/native/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..6bef490a3a72a55c4f5a90ce751b05e5b3c472d0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/DistributionTemplates.h @@ -0,0 +1,394 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#include +#endif + +namespace at::native::templates { + +// ==================================================== Random ======================================================== + +// The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`. +// The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t). +// This casting can result in generating numbers that happen to be greater or equal to `to` value. 
For instance: +// +// auto actual = torch::empty({3, 3}, torch::half); +// actual.random_(0, 65504); +// +// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504 +// and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to` +// moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to +// the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous +// available number for torch::half dtype. +template +int64_t update_from(int64_t from) { + static_assert( + std::is_floating_point_v || + std::is_same_v || + std::is_same_v, "scalar_t must be floating-point type"); + const auto from_plus_1 = static_cast(static_cast(from + 1)); + if (from_plus_1 < from) { + int64_t from_ = std::abs(from + 1); + int n = 0; + while (from_ >>= 1) ++n; + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + from = from_plus_1 + (1LL << (n - std::numeric_limits::digits + 1)); + } + return from; +} + +template +int64_t update_to(int64_t to) { + static_assert( + std::is_floating_point_v || + std::is_same_v || + std::is_same_v, "scalar_t must be floating-point type"); + const auto to_minus_1 = static_cast(static_cast(to - 1)); + if (to_minus_1 >= to) { + int64_t to_ = std::abs(to - 1); + int n = 0; + while (to_ >>= 1) ++n; + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + to = to_minus_1 - (1LL << (n - std::numeric_limits::digits + 1)); + } + return to; +} + +// Return earlier for not invoking kernel. +// See https://github.com/pytorch/pytorch/issues/103418 for more details +#define CHECK_EMPTY_AND_RETURN(tensor) \ + if (tensor.numel() == 0) { \ + return tensor; \ + } + +template class random_kernel, typename RNG> +at::Tensor& random_impl(at::Tensor& self, std::optional generator) { + CHECK_EMPTY_AND_RETURN(self); + auto iter = at::TensorIterator::borrowing_nullary_op(self); + random_kernel()(iter, generator); + return self; +} + +#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \ + TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \ + +#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \ + if (var < -(1LL << digits) || var > (1LL << digits)) { \ + TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \ + "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. 
", \ + "This warning will become an error in version 1.7 release, please fix the code in advance"); \ + } + +inline void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) { + const auto scalar_type = typeMetaToScalarType(dtype); + if (isFloatingType(scalar_type)) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] { + const auto min = static_cast(std::numeric_limits::lowest()); + const auto max = static_cast(std::numeric_limits::max()); + CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); + CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype); + + constexpr auto digits = std::numeric_limits::digits; + WARN_OUT_OF_BOUNDS(from, "from", digits, dtype); + WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype); + }); + } else if (scalar_type == kUInt64) { + // When you do a comparison between int64_t and uint64_t, the usual + // arithmetic conversions say that the int64_t value is promoted to + // unsigned. But this conversion wraps around: if I had -1 as my int64_t, + // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never + // the right thing to do. + CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype); + CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype); + } else if (isIntegralType(scalar_type, /*includeBool=*/true)) { + AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() { + const auto min = static_cast(std::numeric_limits::lowest()); + const auto max = static_cast(std::numeric_limits::max()); + CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); + CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype); + }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool); + } else { + TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types"); + } +} + +template class random_from_to_kernel, typename RNG> +at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, std::optional to_opt, std::optional generator) { + uint64_t range = 0; + auto iter = at::TensorIterator::borrowing_nullary_op(self); + if (to_opt.has_value()) { + // [from, to) + int64_t to = *to_opt; + TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to); + if (isFloatingType(iter.dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] { + from = update_from(from); + to = update_to(to); + TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to); + }); + } + check_from_to_in_range(from, to - 1, self.dtype()); + CHECK_EMPTY_AND_RETURN(self); + range = static_cast(to) - static_cast(from); + random_from_to_kernel()(iter, range, from, generator); + } else if (from != std::numeric_limits::lowest()) { + // [from, std::numeric_limits::max()] + int64_t to_inc = 0; + if (isFloatingType(iter.dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] { + constexpr int64_t scalar_t_max = static_cast(1) << std::numeric_limits::digits; + to_inc = scalar_t_max > std::numeric_limits::max() ? 
std::numeric_limits::max() : static_cast(scalar_t_max); + from = update_from(from); + TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc); + }); + } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) { + AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] { + if constexpr (std::is_same_v) { + to_inc = static_cast(true); + } else { + to_inc = static_cast(std::numeric_limits::max()); + } + }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool); + } else { + TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types"); + } + check_from_to_in_range(from, to_inc, self.dtype()); + CHECK_EMPTY_AND_RETURN(self); + range = static_cast(to_inc) - static_cast(from) + 1; + random_from_to_kernel()(iter, range, from, generator); + } else { + // [std::numeric_limits::lowest(), std::numeric_limits::max()] + // range = 2^64 + CHECK_EMPTY_AND_RETURN(self); + random_from_to_kernel()(iter, generator); + } + return self; +} + +// ==================================================== Normal ======================================================== + +#define CHECK_NORMAL_TENSOR_STD(std) \ + do { \ + TORCH_CHECK( \ + !std.is_complex(), \ + "normal expects standard deviation to be non-complex"); \ + TORCH_CHECK( \ + std.numel() == 0 || std.is_meta() || std.min().ge(0).item(), \ + "normal expects all elements of std >= 0.0"); \ + } while (0) + +#define CHECK_NORMAL_STD(std) \ + TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std); + +template class normal_kernel, typename RNG> +Tensor& normal_impl_(Tensor& self, double mean, double std, std::optional gen) { + CHECK_NORMAL_STD(std); + CHECK_EMPTY_AND_RETURN(self); + + if (self.is_complex()) { + auto float_tensor = at::view_as_real(self); + // variance for normal distribution of the real and imaginary values + // is half of the input variance + normal_kernel()(float_tensor, mean, std/(std::sqrt(2)), gen); + } else { + normal_kernel()(self, mean, std, gen); + } + return self; +} + +template class normal_kernel, typename RNG> +Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, std::optional gen) { + CHECK_NORMAL_STD(std); + auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous); + auto shape = at::infer_size(mean.sizes(), std_tensor.sizes()); + at::native::resize_output(output, shape); + normal_impl_(output, 0, std, gen); + output.add_(mean); + return output; +} + +template class normal_kernel, typename RNG> +Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + auto mean_tensor = at::full({}, mean, output.options()); + auto shape = at::infer_size(mean_tensor.sizes(), std.sizes()); + at::native::resize_output(output, shape); + normal_impl_(output, 0, 1, gen); + // CUDA NB: addcmul_out copies the tensor to be added into the output. + // The previous function here was addcmul_out(output, mean_tensor, output, std, 1); + // The third argument is not a constant reference and hence the samples in output are overwritten. 
+ // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std + output.mul_(std).add_(mean_tensor); + return output; +} + +template class normal_kernel, typename RNG> +Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + auto shape = at::infer_size(mean.sizes(), std.sizes()); + at::native::resize_output(output, shape); + normal_impl_(output, 0, 1, gen); + // CUDA NB: addcmul_out copies the tensor to be added into the output. + // The previous function here was addcmul_out(output, mean, output, std, 1); + // The third argument is not a constant reference and hence the samples in output are overwritten. + // Consequently, the computation performed is mean + mean * std instead of mean + output * std + output.mul_(std).add_(mean); + return output; +} + +template class normal_kernel, typename RNG> +Tensor normal_impl(const Tensor& mean, double std, std::optional gen) { + CHECK_NORMAL_STD(std); + Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous); + normal_out_impl(ret, mean, std, gen); + return ret; +} + +template class normal_kernel, typename RNG> +Tensor normal_impl(double mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + Tensor ret = at::empty_like(std, MemoryFormat::Contiguous); + normal_out_impl(ret, mean, std, gen); + return ret; +} + +template class normal_kernel, typename RNG> +Tensor normal_impl(const Tensor& mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + auto shape = at::infer_size(mean.sizes(), std.sizes()); + Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous); + normal_out_impl(ret, mean, std, gen); + return ret; +} + +// ==================================================== Uniform ======================================================= + +template class uniform_kernel, typename RNG> +at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, std::optional generator) { + if (self.is_complex()) { + CHECK_EMPTY_AND_RETURN(self); + auto float_tensor = at::view_as_real(self); + uniform_impl_(float_tensor, from, to, generator); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] { + [[maybe_unused]] const auto dtype = self.dtype(); + const auto min = static_cast(std::numeric_limits::lowest()); + const auto max = static_cast(std::numeric_limits::max()); + CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); + CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype); + TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to); + TORCH_CHECK((to - from) <= std::numeric_limits::max(), + "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()), + ">::max(), but found to=", to, " and from=", from, + " which result in to-from to exceed the limit"); + from = std::min(std::max(from, min), max); + to = std::max(std::min(to, max), min); + }); + CHECK_EMPTY_AND_RETURN(self); + auto iter = at::TensorIterator::borrowing_nullary_op(self); + uniform_kernel()(iter, from, to, generator); + } + return self; +} + +// ================================================== LogNormal ======================================================= + +template class log_normal_kernel, typename RNG> +at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, std::optional gen) { + TORCH_CHECK(std > 0.0, "log_normal_ expects 
std > 0.0, but found std=", std); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + log_normal_kernel()(iter, mean, std, gen); + return self; +} + +// =================================================== Geometric ====================================================== + +template class geometric_kernel, typename RNG> +Tensor& geometric_impl_(Tensor& self, double p, std::optional gen) { + TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + geometric_kernel()(iter, p, gen); + return self; +} + +// ================================================== Exponential ===================================================== + +template class exponential_kernel, typename RNG> +Tensor& exponential_impl_(Tensor& self, double lambda, std::optional gen) { + TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + exponential_kernel()(iter, lambda, gen); + return self; +} + +// ==================================================== Cauchy ======================================================== + +template class cauchy_kernel, typename RNG> +Tensor& cauchy_impl_(Tensor& self, double median, double sigma, std::optional gen) { + // TODO: instead of variable name 'sigma', use 'gamma' or 'scale' + // the variance, squared sigma, is undefined for cauchy distribution + TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma); + TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype()); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + cauchy_kernel()(iter, median, sigma, gen); + return self; +} + +// ==================================================== Bernoulli ===================================================== + +template class bernoulli_tensor_kernel, typename RNG> +Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, std::optional gen) { + CHECK_EMPTY_AND_RETURN(self); + NoNamesGuard guard; + at::assert_no_internal_overlap(self); + bernoulli_tensor_kernel()(self, p_, gen); + return self; +} + +template class bernoulli_scalar_kernel, typename RNG> +Tensor& bernoulli_impl_(Tensor& self, double p, std::optional gen) { + TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p); + CHECK_EMPTY_AND_RETURN(self); + at::assert_no_internal_overlap(self); + bernoulli_scalar_kernel()(self, p, gen); + return self; +} + +template class bernoulli_tensor_kernel, typename RNG> +Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, std::optional gen) { + // result.resize_as_(self) requires self to have same dtype as result, so we + // use resize_ instead. + // TODO: Fix resize_as_. See pytorch/pytorch#11665. 
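+  // Illustrative note (not part of the upstream header): this out-variant is
+  // the usual resize-then-sample pattern, which a caller could also spell as
+  //   Tensor out = at::empty_like(self);  // `self` holds the per-element probabilities
+  //   out.bernoulli_(self, gen);          // draw one Bernoulli sample per probability
+  // the resize_ below simply reuses the caller-provided `result` storage
+  // instead of allocating a fresh tensor.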
+ result.resize_(self.sizes()); + bernoulli_impl_(result, self, gen); + namedinference::propagate_names(result, self); + return result; +} + +#undef CHECK_OUT_OF_BOUNDS +#undef WARN_OUT_OF_BOUNDS + +} // namespace at::native::templates diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Distributions.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Distributions.h new file mode 100644 index 0000000000000000000000000000000000000000..e115dfdbeccbb84436e6177cb78cb63c359b0df4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Distributions.h @@ -0,0 +1,518 @@ +#pragma once + +#include +#include +#include + +// ROCM hcc doesn't work well with using std:: in kernel functions +#if defined(__CUDA_ARCH__) +#include +#define compat_exp c10::cuda::compat::exp +#define compat_ceil c10::cuda::compat::ceil +#define compat_floor c10::cuda::compat::floor +#define compat_log c10::cuda::compat::log +#define compat_pow c10::cuda::compat::pow +#define compat_sqrt c10::cuda::compat::sqrt +#define compat_tan c10::cuda::compat::tan +#define compat_abs c10::cuda::compat::abs +#define compat_log1p c10::cuda::compat::log1p +#elif defined(__HIPCC__) +#include +#define compat_exp c10::hip::compat::exp +#define compat_ceil c10::hip::compat::ceil +#define compat_floor c10::hip::compat::floor +#define compat_log c10::hip::compat::log +#define compat_pow c10::hip::compat::pow +#define compat_sqrt c10::hip::compat::sqrt +#define compat_tan c10::hip::compat::tan +#define compat_abs c10::hip::compat::abs +#define compat_log1p c10::hip::compat::log1p +#else +#define compat_exp std::exp +#define compat_ceil std::ceil +#define compat_floor std::floor +#define compat_log std::log +#define compat_pow std::pow +#define compat_sqrt std::sqrt +#define compat_tan std::tan +#define compat_abs std::abs +#define compat_log1p std::log1p +#endif + +namespace { + +#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__) +// we cannot use std::isnan directly due to some incompatibility of +// gcc constexpr'ing and nvcc +using std::isnan; +#endif + +// Here sampler_t should be function type scalar_t(void). For gpu +// "sampler" is a device function, but since ROCM doesn't have +// equivalent to nvstd::function, we use a template type parameter to +// capture it. +template +struct BaseSampler { + sampler_t sampler; + C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {} + C10_DEVICE scalar_t sample() { + return sampler(); + } +}; + +// The function `sample_gamma` is +// is adapted from Numpy's distributions.c implementation. +// It is MIT licensed, so here is the copyright: + +/* Copyright 2005 Robert Kern (robert.kern@gmail.com) + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +template +C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler& standard_uniform, BaseSampler& standard_normal) { + accscalar_t scale = 1.0f; + + // Boost alpha for higher acceptance probability. + if (alpha < 1.0f) { + if (alpha == 0.f) return 0.f; + scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha); + alpha += 1.0f; + } + + // This implements the acceptance-rejection method of Marsaglia and Tsang (2000) + // doi:10.1145/358407.358414 + const accscalar_t d = alpha - 1.0f / 3.0f; + const accscalar_t c = 1.0f / compat_sqrt(9.0f * d); + for (;;) { + accscalar_t x, y; + do { + x = standard_normal.sample(); + y = 1.0f + c * x; + } while (y <= 0); + const accscalar_t v = y * y * y; + const accscalar_t u = 1 - standard_uniform.sample(); + const accscalar_t xx = x * x; + if (u < 1.0f - 0.0331f * xx * xx) + return static_cast(scale * d * v); + if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v))) + return static_cast(scale * d * v); + } +} + +/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted + * from TensorFlow's random_binomial_op.cc implementation. That code is under + * copyright: 2019 The TensorFlow Authors. + * + * It was released under the Apache License, Version 2.0 (the "License"), available at: + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +template +C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { + const static scalar_t kTailValues[] = { + 0.0810614667953272, + 0.0413406959554092, + 0.0276779256849983, + 0.02079067210376509, + 0.0166446911898211, + 0.0138761288230707, + 0.0118967099458917, + 0.0104112652619720, + 0.00925546218271273, + 0.00833056343336287 + }; + if (k <= 9) { + return kTailValues[static_cast(k)]; + } + scalar_t kp1sq = (k + 1) * (k + 1); + return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1); +} + + +template +C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler& standard_uniform) { + accscalar_t U; + accscalar_t geom_sum = 0; + scalar_t num_geom = 0; + + accscalar_t logprob = compat_log1p(-prob); + + while (true) { + U = standard_uniform.sample(); + accscalar_t geom = compat_ceil(compat_log(U) / logprob); + geom_sum += geom; + if (geom_sum > count) { + break; + } + num_geom = num_geom + 1; + } + return num_geom; +} + +template +C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler& standard_uniform) { + scalar_t k; + accscalar_t U, V, us; + + // This is spq in the paper. + const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob)); + + // Other coefficients for Transformed Rejection sampling. + const accscalar_t b = 1.15 + 2.53 * stddev; + const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob; + const accscalar_t c = count * prob + 0.5; + const accscalar_t v_r = 0.92 - 4.2 / b; + const accscalar_t r = prob / (1 - prob); + + const accscalar_t alpha = (2.83 + 5.1 / b) * stddev; + const accscalar_t m = compat_floor((count + 1) * prob); + + while (true) { + U = standard_uniform.sample() - 0.5; + V = standard_uniform.sample(); + + us = 0.5 - compat_abs(U); + k = static_cast(compat_floor((2 * a / us + b) * U + c)); + + // Reject non-sensical answers. 
+ if (k < 0 || k > count) { + continue; + } + // Region for which the box is tight, and we can return our calculated value. + // This should happen 0.86 * v_r times. In the limit as n * p is large, + // the acceptance rate converges to ~79% (and in the lower regime it is ~24%). + if (us >= 0.07 && V <= v_r) { + return k; + } + + // This deviates from Hormann's BTRS algorithm, as there is a log missing. + // For all (u, v) pairs outside of the bounding box, this calculates the + // transformed-reject ratio. + V = compat_log(V * alpha / (a / (us * us) + b)); + accscalar_t upperbound = + ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) + + (count + 1) * compat_log((count - m + 1) / (count - k + 1)) + + (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) + + stirling_approx_tail(m) + stirling_approx_tail(count - m) - + stirling_approx_tail(k) - stirling_approx_tail(count - k)); + + if (V <= upperbound) { + return k; + } + } +} + +template +C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler& standard_uniform) { + if (count <= 0.0 || prob <= 0.0) { + return 0; + } else if (prob >= 1.0) { + return count; + } else if (prob <= 0.5) { + if (count * prob >= 10.0) { + // btrs + return btrs(count, prob, standard_uniform); + } else { + // binomial inversion + return binomial_inversion(count, prob, standard_uniform); + } + } else if (prob > 0.5) { + scalar_t qprob = 1.0 - prob; + if (count * qprob >= 10.0) { + // btrs + return count - btrs(count, qprob, standard_uniform); + } else { + // count - binomial inversion + return count - binomial_inversion(count, qprob, standard_uniform); + } + } else { + // prob is nan? + return static_cast(NAN); + } +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h. + */ +template +C10_DEVICE inline scalar_t digamma_one(scalar_t x) { + constexpr accscalar_t PSI_10 = 2.25175258906672110764; + if (x == 0) { + return INFINITY; + } + accscalar_t additional_summand = 0; + int x_is_integer = x == compat_floor(x); + if (x < 0) { + if (x_is_integer) { + return INFINITY; + } + // it is more standard to write this as recursion, but + // nvcc does not like that + additional_summand = -c10::pi / + compat_tan(c10::pi * x); + x = 1 - x; + } + + // Push x to be >= 10 + accscalar_t result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10 + additional_summand; + } + + // Compute asymptotic digamma + static const accscalar_t A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + accscalar_t y = 0; + if (x < 1.0e17f) { + accscalar_t z = 1.0 / (x * x); + y = z * polevl(z, A, 6); + } + return static_cast( + result + compat_log(x) - (0.5f / x) - y + additional_summand); +} + +// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha) +// for random number x drawn from a standard Gamma distribution Gamma(alpha). +template +C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { + // Use a Taylor series expansion for small x. 
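+  // (Added explanatory note, not from the upstream source.) The small-x branch
+  // below expands the lower incomplete gamma function as
+  //   gamma(alpha, x) = x^alpha * sum_{i>=0} (-x)^i / (i! * (alpha + i)),
+  // which is what `series1` accumulates (truncated at i = 5); `series2` is the
+  // same sum with (alpha + i)^2 in the denominator, i.e. -d(series1)/d(alpha).
+  // Both `gamma_cdf` and `gamma_pdf` are left unnormalized by Gamma(alpha);
+  // that factor cancels in the final ratio, and the -digamma(alpha) term in
+  // `gamma_cdf_alpha` accounts for differentiating the normalization itself.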
+ accscalar_t x = static_cast(x_); + accscalar_t alpha = static_cast(alpha_); + if (x < 0.8f) { + accscalar_t numer = 1; + accscalar_t denom = alpha; + auto series1 = numer / denom; + auto series2 = numer / (denom * denom); + for (int i = 1; i <= 5; ++i) { + numer *= -x / static_cast(i); + denom += 1; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + const auto pow_x_alpha = compat_pow(x, alpha); + const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x); + const auto gamma_cdf = pow_x_alpha * series1; + const auto gamma_cdf_alpha = + (compat_log(x) - digamma_one(alpha)) * + gamma_cdf - + pow_x_alpha * series2; + const auto result = -gamma_cdf_alpha / gamma_pdf; + return isnan(result) ? static_cast( 0.f ) : static_cast(result); + } + + // Use a Rice saddle point expansion for large alpha. + if (alpha > 8.0f) { + if (0.9f * alpha <= x && x <= 1.1f * alpha) { + const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha); + const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x) + - 65 * x * x / alpha + alpha * (107 + 3600 * x); + const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha); + return static_cast(numer_1 * numer_2 / denom); + } + const auto denom = compat_sqrt(8 * alpha); + const auto term2 = denom / (alpha - x); + const auto term3 = compat_pow( + x - alpha - alpha * compat_log(x / alpha), + static_cast(-1.5)); + const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3; + const auto term1 = compat_log(x / alpha) * term23 - + compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x)); + const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha)); + const auto numer = x * term1; + return static_cast(-stirling * numer / denom); + } + + // Use a bivariate rational approximation to the reparameterized gradient. + const auto u = compat_log(x / alpha); + const auto v = compat_log(alpha); + static const accscalar_t coef_uv[3][8] = { + {0.16009398, -0.094634809, 0.025146376, -0.0030648343, + 1, 0.32668115, 0.10406089, 0.0014179084}, + {0.53487893, 0.1298071, 0.065735949, -0.0015649758, + 0.16639465, 0.020070113, -0.0035938915, -0.00058392623}, + {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, + 0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07}, + }; + accscalar_t coef_v[8]; + for (int i = 0; i < 8; ++ i) { + coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]); + } + const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); + const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); + return static_cast(compat_exp(p / q)); +} + +// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. +// Assumes x is close to zero and uses a Taylor expansion. +template +C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) { + const scalar_t factor = digamma_one(alpha) + - digamma_one(alpha + beta) - compat_log(x); + scalar_t numer = 1; + scalar_t series = numer / alpha * (factor + 1 / alpha); + for (int i = 1; i <= 10; ++i) { + scalar_t casted_i = static_cast(i); + numer *= (casted_i - beta) * x / casted_i; + const scalar_t denom = alpha + casted_i; + series += numer / denom * (factor + 1 / denom); + } + const scalar_t result = x * compat_pow(1 - x, -beta) * series; + return isnan(result) ? static_cast( 0.f ) : result; +} + +// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta. +// Assumes x is close to zero and uses a Taylor expansion. 
+template +C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) { + const scalar_t factor = digamma_one(alpha + beta) - digamma_one(beta); + scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha; + for (int i = 1; i <= 8; ++i) { + scalar_t casted_i = static_cast(i); + numer *= -x / casted_i; + dbetas = dbetas * (beta - casted_i) + betas; + betas = betas * (beta - casted_i); + series += numer / (alpha + casted_i) * (dbetas + factor * betas); + } + const scalar_t result = -compat_pow(1 - x, 1 - beta) * series; + return isnan(result) ? static_cast( 0.f ) : result; +} + +// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. +// Assumes alpha and beta are both large and uses a Rice saddle point expansion. +// To ensure numerical stability, this computation is performed at higher precision. +template +C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) { + const accscalar_t total = alpha + beta; + const accscalar_t mean = alpha / total; + const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total; + if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) { + // Avoid the singularity at x = mean. + const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * ( + (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * ( + 3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * ( + (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * ( + 8 * (1 - x) * (135 * beta - 11))))); + const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total); + const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total); + return prefactor_num / (1 - x) * poly / prefactor_den; + } + const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total); + const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha)) + * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta)) + / (1 + 1 / (12 * total) + 1 / (288 * total * total)); + const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta); + const accscalar_t axbx = alpha * (x - 1) + beta * x; + const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast(1.5f)) * axbx * axbx; + const accscalar_t term1 = term1_num / term1_den; + const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x)); + const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total); + const accscalar_t term3_den = beta * x + alpha * (x - 1); + const accscalar_t term3 = term3_num / term3_den; + const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) + + alpha * compat_log(alpha / (total * x)); + const accscalar_t term4 = compat_pow(term4_base, static_cast(-1.5f)); + const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4)); + return static_cast(stirling * prefactor * term1234); +} + +// Computes a scaled reparameterized gradient +// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x) +// for random number x drawn from a Beta distribution Beta(alpha,beta). +// This function inputs total=alpha+beta to make it easy to implement +// Dirichlet reparameterized gradients in terms of Betas. 
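+// (Added explanatory note, not from the upstream source.) The Beta-based
+// parameterization works because each coordinate of a Dirichlet sample is
+// marginally Beta distributed: if (x_1, ..., x_K) ~ Dirichlet(alpha_1, ..., alpha_K)
+// and total = sum_j alpha_j, then x_j ~ Beta(alpha_j, total - alpha_j), so the
+// per-coordinate gradient only needs alpha_j and total rather than the whole
+// concentration vector.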
+template +C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) { + accscalar_t x_ = static_cast(x); + accscalar_t alpha_ = static_cast(alpha); + accscalar_t total_ = static_cast(total); + + const scalar_t beta = total - alpha; + const accscalar_t beta_ = total_ - alpha_; + const scalar_t boundary = total * x * (1 - x); + + // Use an asymptotic approximation for x close to 0. + if (x <= 0.5f && boundary < 2.5f) { + return _beta_grad_alpha_small(x, alpha, beta); + } + + // Use an asymptotic approximation for x close to 1. + if (x >= 0.5f && boundary < 0.75f) { + return -_beta_grad_beta_small(1 - x, beta, alpha); + } + + // Use an asymptotic approximation when alpha and (total - alpha) are both large. + if (alpha > 6 && beta > 6) { + return _beta_grad_alpha_mid(x_, alpha_, beta_); + } + + // Use a rational correction to an analytic approximation. + static const accscalar_t c[2][3][3][4] = { + {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863}, + {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033}, + {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}}, + {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814}, + {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057}, + {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}}, + {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565}, + {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181}, + {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}}, + {{{1, -0.02924021934, -0.04438342661, 0.007285809825}, + {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521}, + {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}}, + {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273}, + {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956}, + {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}}, + {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05}, + {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05}, + {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}}, + }; + const accscalar_t u = compat_log(x_); + const accscalar_t a = compat_log(alpha_) - u; + const accscalar_t b = compat_log(total_) - a; + const accscalar_t pow_u[3] = {1, u, u * u}; + const accscalar_t pow_a[3] = {1, a, a * a}; + accscalar_t p = 0.0; + accscalar_t q = 0.0; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + const accscalar_t ua = pow_u[i] * pow_a[j]; + p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3]))); + q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3]))); + } + } + const accscalar_t approx = x_ * (digamma_one(total_) - digamma_one(alpha_)) / beta_; + return static_cast(p / q * approx); +} + +} // namespace diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/EmbeddingBag.h b/phivenv/Lib/site-packages/torch/include/ATen/native/EmbeddingBag.h new file mode 100644 index 0000000000000000000000000000000000000000..2bd72d34b4811817609eb9a4ea4459c02c02f30a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/EmbeddingBag.h @@ -0,0 +1,153 @@ +#include +#include +#include + +#ifdef USE_FBGEMM +#include +#endif + +namespace at::native { + +enum class EmbeddingBagMode { + SUM = 0, + MEAN = 1, + MAX = 2, +}; + +[[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) { + return op1 == static_cast(op2); +} + 
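+// Illustrative usage (not part of the upstream header): this overload and the
+// operator!= below let the raw int64_t `mode` argument received by the
+// embedding_bag kernels be compared against EmbeddingBagMode directly, e.g.
+//   if (mode == EmbeddingBagMode::MAX) { /* also populate max_indices */ }
+// without an explicit static_cast at every call site.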
+[[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) { + return !(op1 == op2); +} + +void check_arguments( + const Tensor& weight, + const Tensor& indices, + const Tensor& offsets, + const int64_t mode, + const std::optional& per_sample_weights, + bool include_last_offset); + +void make_bag_size_out( + Tensor& bag_size_out, + const Tensor& offsets, + const Tensor& indices, + const int64_t mode, + const bool include_last_offset, + const bool requires_grad); + +void make_max_indices_out( + Tensor& max_indices_out, + const Tensor& weight, + const Tensor& indices, + const Tensor& offsets, + const Tensor& bag_size, + const int64_t mode, + bool include_last_offset); + +void make_offset2bag_out( + Tensor& offset2bag, + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const Tensor& offsets, + const int64_t mode, + const std::optional& per_sample_weights, + const int64_t padding_idx = -1); + +#ifdef USE_FBGEMM + +template +struct _CallbackAndBlockSize { + using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature::Type; + + int64_t blockSize = -1; + TCallback callback = nullptr; + + static TCallback generateCallback(int64_t block_size) { + return fbgemm::GenerateEmbeddingSpMDM( + block_size, + has_weight, + /* normalize_by_lengths */false, + /* prefetch */16, + /* is_weight_positional */false, + /* use_offsets */true); + } + + _CallbackAndBlockSize() = default; + + explicit _CallbackAndBlockSize(std::optional maybe_block_size) + : blockSize(maybe_block_size.value_or(-1)) + , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr) + {} +}; + +template +struct _EmbeddingBagKernelCacheImpl : private StorageMixins... { + + _EmbeddingBagKernelCacheImpl() = default; + // use each of the mixins to store corresponding kernel and block size + explicit _EmbeddingBagKernelCacheImpl(std::optional maybe_block_size) + : StorageMixins(maybe_block_size)... 
+ {} + + // this method is thread safe (call sites may call from different threads) + template + typename _CallbackAndBlockSize::TCallback + getCallback(int64_t block_size) const { + // if the cache doesn't store the kernel for the incoming block size + // (so it is different from the one stored in corresponding mixin) + // regenerate the kernel (not writing it into the cache so we avoid locks) + if (block_size != _CallbackAndBlockSize::blockSize) { + return _CallbackAndBlockSize::generateCallback(block_size); + } + // else retrieve the cached kernel from the corresponding mixin + return _CallbackAndBlockSize::callback; + } +}; + +// instantiate the cache with the list of storage mixins +// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file +using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl< + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize>; +#else +struct _EmbeddingBagKernelCache { + explicit _EmbeddingBagKernelCache(std::optional /* maybe_block_size */) {} +}; +#endif + +void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag, + Tensor& bag_size, Tensor* max_indices, + const Tensor &weight, const Tensor &indices, + const Tensor &offsets, const int64_t mode = 0, + const std::optional& per_sample_weights = std::nullopt, + bool include_last_offset = false, + int64_t padding_idx = -1, + _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr); + +void _embedding_bag_cpu_out( + at::Tensor& output, + at::Tensor& offset2bag, + at::Tensor& bag_size, + at::Tensor* p_max_indices, + const at::Tensor& weight, + const at::Tensor& indices, + const at::Tensor& offsets, + const bool scale_grad_by_freq, + const int64_t mode, + const bool sparse, + const std::optional& per_sample_weights, + const bool include_last_offset, + const std::optional& padding_idx, + _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Fill.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Fill.h new file mode 100644 index 0000000000000000000000000000000000000000..26a67717b6383c24ca47ea49396ba04c98530beb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Fill.h @@ -0,0 +1,21 @@ +// Functions that fill Tensors with constants. Implementations are in Fill.cpp. 
+ +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { +class Tensor; +struct TensorIterator; + +namespace native { + +DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub) + +Tensor& fill_out(Tensor& self, const Scalar& value); + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ForeachUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ForeachUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..73b40a8434130816c5aac9292d7977ad3722d058 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ForeachUtils.h @@ -0,0 +1,409 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include + +namespace at::native { +namespace { +// Check if tensor list has either a boolean tensor or a integer tensor +inline bool has_integral_tensor(TensorList tensors, const bool includeBool) { + return std::any_of( + tensors.begin(), tensors.end(), [&includeBool](const auto& t) { + return at::isIntegralType(t.scalar_type(), includeBool); + }); +} +// check if tensor list has bool tensors +inline bool has_bool_tensor(TensorList tensors) { + return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool { + return t.scalar_type() == ScalarType::Bool; + }); +} + +// Check foreach API restrictions +// - Tensor lists must be non-empty. +// - All TensorLists and ScalarLists must have the same number of elements. +// - Corresponding tensors must have the same size. +inline void check_foreach_api_restrictions(TensorList tensors) { + TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor."); +} + +inline void check_foreach_api_restrictions( + TensorList tensors, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors); + TORCH_CHECK( + tensors.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list."); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2) { + TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK( + tensors1.size() == tensors2.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors2.size()); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + TensorList tensors3) { + TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK( + tensors1.size() == tensors2.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors2.size()); + TORCH_CHECK( + tensors1.size() == tensors3.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors3.size()); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + TensorList tensors3, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors1, tensors2, tensors3); + TORCH_CHECK( + tensors1.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list, got ", + tensors1.size(), + " and ", + scalars.size()); +} + +inline void 
check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors1, tensors2); + TORCH_CHECK( + tensors1.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list, got ", + tensors1.size(), + " and ", + scalars.size()); +} + +// Helper function called in check_fast_path_restrictions to check whether all +// corresponding tensors (aligning in index across the tensorLists) share the +// same device and dtype. +inline bool _check_tensors_share_device_and_dtype( + ArrayRef tensorLists, + const bool skip_dtype_check = false) { + const auto expected_dtype = tensorLists[0][0].dtype(); + const auto expected_device = tensorLists[0][0].device(); + + auto is_tensor_okay = [&](const Tensor& tensor) { + return (skip_dtype_check || tensor.dtype() == expected_dtype) && + tensor.device() == expected_device && tensor.layout() == at::kStrided && + tensor.is_non_overlapping_and_dense(); + }; + + for (const auto& tensorList : tensorLists) { + for (const auto& tensor : tensorList) { + if (!is_tensor_okay(tensor)) { + return false; + } + } + } + + return true; +} + +// Helper function called in check_fast_path_restrictions to check if +// corresponding tensors in tensor lists have the same sizes and strides. +inline bool _check_tensors_share_sizes_and_strides( + ArrayRef tensorLists) { + auto is_diff_stride = [](const IntArrayRef& size, + const IntArrayRef& left_stride, + const IntArrayRef& right_stride) -> bool { + const size_t size_size = size.size(); + for (const auto dim : c10::irange(size_size)) { + if (size[dim] == 1) + continue; + if (left_stride[dim] != right_stride[dim]) { + return true; + } + } + return false; + }; + for (const auto i : c10::irange(1, tensorLists.size())) { + for (const auto j : c10::irange(tensorLists[0].size())) { + if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() || + is_diff_stride( + tensorLists[0][j].sizes(), + tensorLists[0][j].strides(), + tensorLists[i][j].strides())) { + return false; + } + } + } + + return true; +} + +// Helper function called in check_fast_path_restrictions to check whether +// all tensors type promote properly with the scalars in scalarList. This +// function assumes that _check_tensors_share_device_and_dtype has already been +// called so that all corresponding tensors in tensorLists have the same dtype. +// Then, it is sufficient to check the type promotion with just one tensorList. +inline bool _check_tensors_do_type_promotion_with_scalars( + TensorList tensorList, + ArrayRef scalarList = {}, + bool does_op_promote_integer_inputs_to_float = false) { + for (const auto i : c10::irange(tensorList.size())) { + // For division, integer inputs will result in float. + if (does_op_promote_integer_inputs_to_float) { + if (at::isIntegralType( + tensorList[i].scalar_type(), /*includeBool*/ true)) { + return false; + } + } + if (!scalarList.empty()) { + const auto& scalar = + scalarList.size() == 1 ? scalarList[0] : scalarList[i]; + const auto& tensor = tensorList[i]; + // note(mkozuki): This check might be responsible for + // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path. + if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) { + return false; + } + } + } + + return true; +} + +// To go via 'fast' path, several conditions must be satisfied +// - All tensors in all lists must have the same dtype. 
+// - All tensors must be on the same device +// - All tensors must have strided layout +// - All tensors must be non-overlapping and dense +// - Resulting tensor must have the same dtype as the input one + +// [note: what's ``does_op_promote_integer_inputs_to_float=true``?] +// ``does_op_promote_integer_inputs_to_float=true`` means that the result of +// the op will be float even if inputs are integer or boolean, which +// currently fast path does not support. In short, this flag, when +// turned on, gatekeeps the op from going down the fastpath. + +// Please, make sure to call check_foreach_api_restrictions before calling this +// method. There is a set of preconditions that have to be satisfied. +inline bool check_fast_path_restrictions( + ArrayRef tensorLists, + ArrayRef scalarList = {}, + bool does_op_promote_integer_inputs_to_float = false) { + return _check_tensors_share_device_and_dtype(tensorLists) && + _check_tensors_share_sizes_and_strides(tensorLists) && + _check_tensors_do_type_promotion_with_scalars( + tensorLists[0], + scalarList, + does_op_promote_integer_inputs_to_float); +} + +inline std::vector convert_tensor_to_scalar_list( + const Tensor& scalarList_, + int64_t expect_length) { + std::vector scalarList; + TORCH_CHECK( + scalarList_.device() == c10::kCPU, + "Expected scalars to be on CPU, got ", + scalarList_.device(), + " instead."); + TORCH_CHECK( + scalarList_.is_contiguous(), "Expected scalars to be contiguous."); + TORCH_CHECK( + scalarList_.dim() == 1, + "Expected packed scalar Tensor to be of dimension 1. Got ", + scalarList_.dim(), + " instead."); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + kComplexHalf, + kHalf, + kBool, + kBFloat16, + scalarList_.scalar_type(), + "convert_tensor_to_scalar_list", + [&]() { + const scalar_t* scalar_data = scalarList_.const_data_ptr(); + TORCH_CHECK( + (expect_length == scalarList_.size(0)), + "Expected length of scalars to match input of length ", + expect_length, + " but got ", + scalarList_.size(0), + " instead."); + for (int64_t i = 0; i < scalarList_.size(0); i++) { + scalarList.emplace_back(scalar_data[i]); + } + }); + return scalarList; +} + +// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?] +inline bool can_use_fast_route( + ArrayRef tensorLists, + ArrayRef scalarList = {}, + bool does_op_promote_integer_inputs_to_float = false) { + return check_fast_path_restrictions( + tensorLists, scalarList, does_op_promote_integer_inputs_to_float); +} + +// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?] 
+inline bool can_use_fast_route( + TensorList tensors1, + TensorList tensors2, + bool does_op_promote_integer_inputs_to_float = false) { + return can_use_fast_route( + {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float); +} + +using DeviceDtypeKey = std::pair; +using IndicesT = std::vector; +using nested_optional_tensorvec_t = + std::vector>>; +using TensorsAndIndicesT = std::pair; +using FlatMap = std::unordered_map< + DeviceDtypeKey, + TensorsAndIndicesT, + ParamsHash>; + +inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( + const nested_optional_tensorvec_t& nested_tensorlist, + const bool with_indices) { + FlatMap grouped_tensors_with_indices; + + TORCH_CHECK(!nested_tensorlist.empty()); + TORCH_CHECK(!nested_tensorlist[0].empty()); + const auto num_lists = nested_tensorlist.size(); + const auto num_tensors = nested_tensorlist[0].size(); + + TORCH_CHECK(std::all_of( + nested_tensorlist.cbegin(), + nested_tensorlist.cend(), + [&](const auto& tensorlist) -> bool { + // note(crcrpar): Allow empty tensorlists following + // ref: + // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24 + return tensorlist.size() == num_tensors || tensorlist.size() == 0; + })); + + for (const auto& tensor_index : c10::irange(num_tensors)) { + const auto key = [&]() -> DeviceDtypeKey { + const auto t = nested_tensorlist[0][tensor_index]; + TORCH_CHECK( + t.has_value(), + "Tensors of the first list of nested Tensor lists are supposed to be defined but ", + "the ", + tensor_index, + "-th Tensor is not."); + return {t->device(), t->scalar_type()}; + }(); + TORCH_CHECK( + std::all_of( + nested_tensorlist.cbegin(), + nested_tensorlist.cend(), + [&](const auto& tensorlist) -> bool { + if (tensorlist.size() == 0) { + return true; + } + const auto& tensor = tensorlist[tensor_index]; + // note(crcrpar): Currently the scope of this function is + // optimizers so there could be `state_steps` and other scalars + // whose elements are float tensors no matter what the parameter's + // dtype is. + if (!tensor.has_value()) { + return true; + } else { + const auto s = tensor->scalar_type(); + const auto d = tensor->device(); + // Note: `step` or `state_step` is float32 by default. + if (key.first == d) { + return key.second == s || s == at::ScalarType::Float || + s == at::ScalarType::Double; + } else if (d.is_cpu()) { + // note(crcrpar): There are some test cases (e.g. + // TestOptim::test_adam) where state_steps are on CPU and the + // others are on CUDA. Currently a state_step Tensor has the + // dtype of float. + return s == at::ScalarType::Float || + s == at::ScalarType::Double; + } else { + return false; + } + } + }), + "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding"); + if (!grouped_tensors_with_indices.count(key)) { + grouped_tensors_with_indices.insert( + {key, + TensorsAndIndicesT{ + [&]() -> nested_optional_tensorvec_t { + nested_optional_tensorvec_t nested_tensorvec; + nested_tensorvec.reserve(num_lists); + for (const auto& i : c10::irange(num_lists)) { + std::vector> tensors; + if (!nested_tensorlist[i].empty()) { + // NB: num_tensors is the max possible length for any of + // the inner lists of tensor references. Reserving the max + // trades memory for perf. This should not have significant + // impact. 
+ tensors.reserve(num_tensors); + } + nested_tensorvec.emplace_back(tensors); + } + return nested_tensorvec; + }(), + [&]() -> IndicesT { + if (!with_indices) { + return {}; + } else { + IndicesT indices; + indices.reserve(num_tensors); + return indices; + } + }()}}); + } + for (const auto& list_index : c10::irange(num_lists)) { + if (!nested_tensorlist[list_index].empty()) { + grouped_tensors_with_indices[key].first[list_index].emplace_back( + nested_tensorlist[list_index][tensor_index]); + } + } + if (with_indices) { + grouped_tensors_with_indices[key].second.emplace_back(tensor_index); + } + } + + return grouped_tensors_with_indices; +} + +} // namespace +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/FractionalMaxPooling.h b/phivenv/Lib/site-packages/torch/include/ATen/native/FractionalMaxPooling.h new file mode 100644 index 0000000000000000000000000000000000000000..1ad7a625ba3b9afa22f13aaff642220b2f71f34c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/FractionalMaxPooling.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include + +namespace at::native { + +template +inline std::vector generate_intervals( + scalar_t sample, + int64_t inputSize, + int64_t outputSize, + int64_t poolSize) { + std::vector sequence(outputSize); + if (outputSize > 1) { + scalar_t alpha = static_cast(inputSize - poolSize) / + static_cast(outputSize - 1); + + for (const auto i : c10::irange(outputSize - 1)) { + sequence[i] = + static_cast((i + sample) * alpha) - static_cast(sample * alpha); + } + } + if (outputSize > 0) { + sequence[outputSize - 1] = inputSize - poolSize; + } + return sequence; +} + +template +inline void fractional_max_pool_check_shape( + const Tensor& input, + const Tensor& randomSamples) { + + TORCH_CHECK( + input.scalar_type() == randomSamples.scalar_type(), + "Expect _random_samples to have the same dtype as input"); + + int64_t ndimension = randomSamples.ndimension(); + TORCH_CHECK( + ndimension == 3, + "Expect _random_samples to have 3 dimensions, got ", ndimension); + + int64_t N = randomSamples.size(0); + int64_t C = randomSamples.size(1); + int64_t D = randomSamples.size(2); + + int64_t input_batch = 0, input_channel = 0; + if (ndim == 2) { + // fractional_max_pool2d + if (input.ndimension() == 3) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } else { + // factional_max_pool3d + if (input.ndimension() == 4) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } + + TORCH_CHECK( + N >= input_batch, + "Expect _random_samples.size(0) no less then input batch size."); + TORCH_CHECK( + C == input_channel, + "Expect _random_samples.size(1) equals to input channel size."); + TORCH_CHECK( + D == ndim, + "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, "."); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..db42959bd0309c8f0ccd5bb4da316f09f7bb1bca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace at { +struct TensorIterator; + +namespace native { + +using _compute_linear_combination_fn = void(*)( + 
TensorIterator& iter, + int64_t in_stride, + int64_t coeff_stride, + int64_t num_summations +); + +DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub) + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/FusedAdagrad.h b/phivenv/Lib/site-packages/torch/include/ATen/native/FusedAdagrad.h new file mode 100644 index 0000000000000000000000000000000000000000..bf2540202c686789cc4e11f7fbcd1b3f858509c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/FusedAdagrad.h @@ -0,0 +1,20 @@ +#include +#include + +namespace at::native { + +using fused_adagrad_fn = void (*)( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& state_sum, + const at::Tensor& state_step, + const double lr, + const double lr_decay, + const double weight_decay, + const double eps, + const bool maximize, + const float* grad_scale_ptr); + +DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/FusedAdam.h b/phivenv/Lib/site-packages/torch/include/ATen/native/FusedAdam.h new file mode 100644 index 0000000000000000000000000000000000000000..530424b5316890c174876a221da3f10631b7b370 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/FusedAdam.h @@ -0,0 +1,27 @@ +#include +#include + +namespace at::native { + +enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; + +using fused_adam_fn = void (*)( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& exp_avg, + const at::Tensor& exp_avg_sq, + const at::Tensor& max_exp_avg_sq, + const at::Tensor& state_step, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const float* grad_scale_ptr, + const ADAM_MODE); + +DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/FusedSGD.h b/phivenv/Lib/site-packages/torch/include/ATen/native/FusedSGD.h new file mode 100644 index 0000000000000000000000000000000000000000..283ccd1f3f2a2210cf510074bb8cee756d7904ef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/FusedSGD.h @@ -0,0 +1,21 @@ +#include +#include + +namespace at::native { + +using fused_sgd_fn = void (*)( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& momentum_buffer, + const double weight_decay, + const double momentum, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const float* grad_scale_ptr); + +DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Gelu.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Gelu.h new file mode 100644 index 0000000000000000000000000000000000000000..111a729e1d764472b0975aac9da4ee92ed7cf694 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Gelu.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +namespace at::native { +// These constants control the approximation behavior of gelu function. 
+enum class GeluType { + None, // Baseline Gelu + Tanh, // Tanh Gelu Approximation + END +}; + +inline GeluType get_gelutype_enum(const std::string_view approximate) { + if (approximate == "none") { + return GeluType::None; + } else if (approximate == "tanh") { + return GeluType::Tanh; + } else { + TORCH_CHECK(false, "approximate argument must be either none or tanh."); + } +} + +inline std::string gelutype_to_string(const GeluType type) { + switch(type) { + case GeluType::None: return "none"; + case GeluType::Tanh: return "tanh"; + default: TORCH_CHECK(false, "unknown GELU type: ", static_cast(type)); + } +} + + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/GridSampler.h b/phivenv/Lib/site-packages/torch/include/ATen/native/GridSampler.h new file mode 100644 index 0000000000000000000000000000000000000000..92e61f75539dc9a5a45a966170bb0827de82feda --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/GridSampler.h @@ -0,0 +1,298 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at::native { + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size, + bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1) * size - 1) / 2; + } +} + +// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size, + bool align_corners, scalar_t *grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast(size) / 2; + return ((coord + 1) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template +static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) { + return std::min(static_cast(clip_limit - 1), std::max(in, static_cast(0))); +} + +// clip_coordinates_set_grad works similarly to clip_coordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit, + scalar_t *grad_in) { + // Note that it is important for the gradient calculation that borders + // are considered out of bounds. 
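+  // (Added explanatory note, not from the upstream source.) Treating the
+  // borders as out of bounds means the clamp has derivative 0 whenever the
+  // input sits at or beyond either border (in <= 0 or in >= clip_limit - 1)
+  // and derivative 1 strictly inside, which is exactly what the branches
+  // below store into *grad_in.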
+ if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + scalar_t max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low, + int64_t twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = std::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = std::fmod(in, span); + int flips = static_cast(std::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// reflect_coordinates_set_grad works similarly to reflect_coordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low, + int64_t twice_high, scalar_t *grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. + scalar_t extra = std::fmod(in, span); + int flips = static_cast(std::floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +// Mapping the out-of-boundary points back into boundary +// This would only affect padding_mode=border or reflection +template +static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2*(size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2*size - 1); + } + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +static inline scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + +// grid_sampler_compute_source_index_set_grad works similarly to +// grid_sampler_compute_source_index except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. 
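+// (Added explanatory note, not from the upstream source.) The reported
+// derivative is assembled by the chain rule: the function below multiplies the
+// unnormalization gradient by the clip and/or reflection gradients
+// (*grad_in = grad_unnormalize * grad_refl * grad_clip), so callers obtain
+// d(source index) / d(normalized grid coordinate) as a single value.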
+template +static inline scalar_t grid_sampler_compute_source_index_set_grad( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t *grad_in) { + scalar_t grad_clip, grad_refl; + coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl); + } else { + coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl); + } + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + return coord; +} + +static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template +static inline scalar_t get_value_bounded( + const scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + +template +static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w, + int64_t sH, int64_t sW, int64_t H, int64_t W, + scalar_t delta) { + if (within_bounds_2d(h, w, H, W)) { + data[h * sH + w * sW] += delta; + } +} + +template +static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w, + int64_t sD, int64_t sH, int64_t sW, + int64_t D, int64_t H, int64_t W, + scalar_t delta) { + if (within_bounds_3d(d, h, w, D, H, W)) { + data[d * sD + h * sH + w * sW] += delta; + } +} + +template +static inline void add_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. 
`d coeff / d x` +template +static inline void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/GridSamplerUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/GridSamplerUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..2d4208bb47e3193c3d3ad3846df696a54ca4d82f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/GridSamplerUtils.h @@ -0,0 +1,105 @@ +#pragma once + +// See NOTE: [Tensor vs. TensorBase] +// https://github.com/pytorch/pytorch/pull/66979 +#include +#include +#include + +namespace at::native { + +namespace detail { + +enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; +enum class GridSamplerPadding {Zeros, Border, Reflection}; + +} // namespace detail + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// See NOTE [ grid_sampler Native Functions ]. +inline void check_grid_sampler_common( + const TensorBase& input, + const TensorBase& grid +) { + auto input_opt = input.options(); + auto grid_opt = grid.options(); + + TORCH_CHECK( + input.defined(), + "grid_sampler(): expected input to not be undefined"); + TORCH_CHECK( + grid.defined(), + "grid_sampler(): expected grid to not be undefined"); + TORCH_CHECK( + input_opt.device() == grid_opt.device(), + "grid_sampler(): expected input and grid to be on same device, but input " + "is on ", input_opt.device(), " and grid is on ", grid_opt.device()); + TORCH_CHECK( + input_opt.layout() == kStrided && grid_opt.layout() == kStrided, + "grid_sampler(): expected input and grid to have torch.strided layout, but " + "input has ", input_opt.layout(), " and grid has ", grid_opt.layout()); + TORCH_CHECK( + input.size(0) == grid.size(0), + "grid_sampler(): expected grid and input to have same batch size, but got " + "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes()); + TORCH_CHECK( + grid.size(-1) == input.dim() - 2, + "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last " + "dimension, but got grid with sizes ", grid.sizes()); + + for (const auto i : c10::irange(2, input.dim())) { + TORCH_CHECK(input.size(i) > 0, + "grid_sampler(): expected input to have non-empty spatial dimensions, " + "but input has sizes ", input.sizes(), " with dimension ", i, " being " + "empty"); + } +} + +// See NOTE [ grid_sampler Native Functions ]. +inline void check_grid_sampler_2d( + const TensorBase& input, + const TensorBase& grid +) { + TORCH_CHECK( + input.dim() == 4 && input.dim() == grid.dim(), + "grid_sampler(): expected 4D input and grid with same number of " + "dimensions, but got input with sizes ", input.sizes(), + " and grid with sizes ", grid.sizes()); +} + +// See NOTE [ grid_sampler Native Functions ]. 
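// (Illustrative note, assuming the usual grid_sample conventions.) For the 3-D
// case checked below, input is expected to be (N, C, D_in, H_in, W_in) and grid
// (N, D_out, H_out, W_out, 3); bicubic interpolation is rejected because it is
// only implemented for the 4-D spatial case.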
+inline void check_grid_sampler_3d( + const TensorBase& input, + const TensorBase& grid, + int64_t interpolation_mode +) { + TORCH_CHECK( + input.dim() == 5 && input.dim() == grid.dim(), + "grid_sampler(): expected 5D input and grid with same number of " + "dimensions, but got input with sizes ", input.sizes(), + " and grid with sizes ", grid.sizes()); + TORCH_CHECK( + !(input.dim() == 5 && + static_cast(interpolation_mode) == + GridSamplerInterpolation::Bicubic), + "grid_sampler(): bicubic interpolation only supports 4D input"); +} + +// See NOTE [ grid_sampler Native Functions ]. +// cudnn does not support inputs larger than 1024. +inline bool cond_cudnn_grid_sampler( + const TensorBase& input, + const TensorBase& grid +) { + return ( + at::native::cudnn_is_acceptable(input) && + at::native::cudnn_is_acceptable(grid) && + at::native::canUse32BitIndexMath(input) && + at::native::canUse32BitIndexMath(grid) && + input.dim() == 4 && + input.sym_size(1) <= 1024); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Histogram.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Histogram.h new file mode 100644 index 0000000000000000000000000000000000000000..d49ad5b2b1b164f2554556bca220c5c622a128e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Histogram.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace at::native { + +using histogramdd_fn = void(*)(const Tensor&, const std::optional&, bool, Tensor&, const TensorList&); +using histogramdd_linear_fn = void(*)(const Tensor&, const std::optional&, bool, Tensor&, const TensorList&, bool); +using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector &leftmost_edges, std::vector &rightmost_edges); + +DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub) +DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub) +DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/IndexKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..e1709eca647f2a878355947bee7e2defab79fcdb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/IndexKernel.h @@ -0,0 +1,41 @@ +#pragma once +#include +#include + +namespace at { +class Tensor; +class TensorBase; +struct TensorIterator; +struct TensorIteratorBase; +} + +namespace c10 { +class Scalar; +} + +namespace at::native { + +using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides); +using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source); +using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride); +using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate); +using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate); +using take_fn = void(*)(TensorIterator & iter, const TensorBase& input); +using flip_fn = void(*)(TensorIterator &, const bool); +using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar); +using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); +using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &); + 
+DECLARE_DISPATCH(index_fn, index_stub) +DECLARE_DISPATCH(index_fill_fn, index_fill_stub) +DECLARE_DISPATCH(index_copy_fn, index_copy_stub) +DECLARE_DISPATCH(index_put_fn, index_put_stub) +DECLARE_DISPATCH(put_fn, put_stub) +DECLARE_DISPATCH(take_fn, take_stub) +DECLARE_DISPATCH(flip_fn, flip_stub) +DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub) +DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub) +DECLARE_DISPATCH(masked_select_fn, masked_select_stub) +DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/IndexingUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/IndexingUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..6971b1f1d67ec595c42ec64a0e602bc60fb937f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/IndexingUtils.h @@ -0,0 +1,164 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at::native { + +[[noreturn]] +static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) { + TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx, + " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx); +} + +[[maybe_unused]] static std::vector expandTensors( + const Tensor& self, + IOptTensorListRef indices) { + // If indices come in as ByteTensor or BoolTensor (masks), expand them into + // the equivalent indexing by LongTensors + std::vector result; + for (const auto& index_opt : indices) { + if (!index_opt.has_value()) { + result.emplace_back(); + } else { + const auto& index = *index_opt; + if (index.scalar_type() == kByte || index.scalar_type() == kBool) { + if (index.scalar_type() == kByte) { + TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \ + " please use a dtype torch.bool instead."); + } + // The sizes of the ByteTensor mask or bool tensor must match the sizes of the + // corresponding dimensions in self + for (const auto j : c10::irange(index.dim())) { + int64_t srcIdx = static_cast(result.size() + j); + if (index.size(j) != self.size(srcIdx)) { + invalid_mask(self, srcIdx, index, j); + } + } + // Replace with nonzeros + auto nonzero = index.nonzero(); + for (const auto j : c10::irange(index.dim())) { + result.emplace_back(nonzero.select(1, j)); + } + } else { + result.emplace_back(index); + } + } + } + return result; +} + +[[maybe_unused]] static void checkIndexTensorTypes( + IOptTensorListRef indices, + bool allow_int = false) { + for (const auto& tensor : indices) { + if (tensor.has_value() && tensor->defined()) { + auto scalarType = tensor->scalar_type(); + if (allow_int) { + if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) { + TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors"); + } + } else { + if (scalarType != kLong && scalarType != kByte && scalarType != kBool) { + TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors"); + } + } + } + } +} + +inline torch::List> toListOfOptionalTensors(ArrayRef list) { + torch::List> result; + result.reserve(list.size()); + for (const Tensor& a : list) { + result.push_back(a); + } + return result; +} + +inline torch::List> toListOfOptionalTensors(ArrayRef list) { + torch::List> result; + result.reserve(list.size()); + for (const IValue& a : list) { + result.push_back(a.isTensor() ? 
std::optional(a.toTensor()) : std::optional()); + } + return result; +} + +[[maybe_unused]] static bool hasContiguousSubspace(TensorList tl) { + // true if all the non-null tensors are adjacent + auto isDefined = [](const Tensor & tensor){ return tensor.defined(); }; + auto isNull = [](const Tensor & tensor){ return !tensor.defined(); }; + auto start = std::find_if(tl.begin(), tl.end(), isDefined); + auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined); + auto it = std::find_if(start, stop.base(), isNull); + return it == stop.base(); +} + +// Transposes the tensor and indices together so that all the non-null indices +// index the first k dimensions of the tensor. Returns the transposed tensor +// and the reordered indices. For example: +// transposeToFront(tensor, {nullptr, a, nullptr, b}) +// returns +// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr} +[[maybe_unused]] static std::tuple> transposeToFront( + const Tensor& self, + TensorList indices) { + std::vector dims; + std::vector transposedIndices; + dims.reserve(self.dim()); + for (const auto i : c10::irange(self.dim())) { + if (indices[i].defined()) { + dims.push_back(i); + transposedIndices.emplace_back(indices[i]); + } + } + for (const auto i : c10::irange(self.dim())) { + if (!indices[i].defined()) { + dims.push_back(i); + transposedIndices.emplace_back(); + } + } + return std::make_tuple(self.permute(dims), std::move(transposedIndices)); +} + +inline std::tuple, std::vector> +transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) { + std::vector dims; + std::vector invPerm; + std::vector transposedIndices; + dims.reserve(self.dim()); + invPerm.resize(self.dim()); + for (const auto i : c10::irange(self.dim())) { + if (indices[i].defined()) { + dims.push_back(i); + transposedIndices.emplace_back(indices[i]); + } + } + for (const auto i : c10::irange(self.dim())) { + if (!indices[i].defined()) { + dims.push_back(i); + transposedIndices.emplace_back(); + } + } + for (const auto i : c10::irange(self.dim())) { + invPerm[dims[i]] = i; + } + return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm)); +} + +struct AdvancedIndex { + AdvancedIndex(const Tensor& src, TensorList indices); + + Tensor src; + std::vector indices; + DimVector indexed_sizes; + DimVector indexed_strides; + int64_t dims_before; + int64_t dims_after; +}; + + +} //namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Lerp.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Lerp.h new file mode 100644 index 0000000000000000000000000000000000000000..5bfc29b25fcad3264f8e6b0c108cbb1169281682 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Lerp.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +template +C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) { + return std::abs(weight) < scalar_t(0.5); +} +template +C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex weight) { + // Avoid the sqrt in abs(weight) + return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25); +} + +template +C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) { + using opmath_t = at::opmath_type; + using opmath_weight_t = at::opmath_type; + + opmath_t self = self_; + opmath_t end = end_; + opmath_weight_t weight = weight_; + + // Conditional for better numeric. 
This has been discussed in + // https://github.com/pytorch/pytorch/pull/18871 + return is_lerp_weight_small(weight) + ? self + weight * (end - self) + : end - (end - self) * (opmath_t(1) - weight); +} + +using lerp_fn_scalar = void (*)( + at::TensorIteratorBase& iter, + const Scalar& weight); + +using lerp_fn_tensor = void (*)( + at::TensorIteratorBase& iter); + +DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight) +DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/LinearAlgebra.h b/phivenv/Lib/site-packages/torch/include/ATen/native/LinearAlgebra.h new file mode 100644 index 0000000000000000000000000000000000000000..5ecef0d72d4790c4f297724ebe01038180ef1612 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/LinearAlgebra.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +} + +namespace at::native { + +using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha); +DECLARE_DISPATCH(addr_fn, addr_stub) +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..f1cd236c5f4c0ff3cdc3b492124107225718f389 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h @@ -0,0 +1,624 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +inline c10::MaybeOwned expect_resolved_conj(const Tensor& tensor) { + if (tensor.is_conj()) { + return c10::MaybeOwned::owned(tensor.resolve_conj()); + } else { + return c10::MaybeOwned::borrowed(tensor); + } +} + +inline DimVector batched_matrix_contiguous_strides( + const IntArrayRef sizes, + const bool f_contig = false) { + // f_contig chooses between the strides of a batch of Fortran (F-contiguous) + // and C-contiguous matrices + auto strides = c10::contiguous_strides(sizes); + auto dim = strides.size(); + + if (f_contig && dim >= 2) { + // Fix the strides of the last two dimensions, so that we return + // C-contiguous batches of F-contiguous matrices. + strides[dim - 1] = std::max(sizes[dim - 2], static_cast(1)); + strides[dim - 2] = 1; + } + return strides; +} + +/* + * Clones a Tensor so that the following conditions hold: + * If we think of a Tensor of having size (B, M, N), where B is any number + * of batch dimensions, then: + * - Each (M, N) matrix is in column major form + * - Let Tensor P have size (B, M, N) and Q have size (B, M', N'). + * Then when laid out in memory, the M by N matrix starting at + * P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N' + * matrix starting at Q.data_ptr()[B * M' * N']. + */ +inline Tensor cloneBatchedColumnMajor(const Tensor& src) { + // If src is already in batched column major format, then + // this will be efficient (no reordering of the data will occur) + // because the first transpose will make the tensor contiguous, + // and cloning a contiguous tensor is fast. 
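  // (Illustrative note.) src.mT() is a lazily transposed view, so cloning it
  // contiguously materializes row-major storage of the transpose -- that is,
  // column-major storage of the original matrices -- and the in-place transpose
  // below restores the logical (..., M, N) shape without moving any data.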
+ auto result = src.mT().clone(at::MemoryFormat::Contiguous); + result.transpose_(-2, -1); + return result; +} + +/* + * contig chooses between C-contig (true) and F-contig (false) + */ +inline c10::MaybeOwned borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) { + return cond ? c10::MaybeOwned::borrowed(borrow) + : c10::MaybeOwned::owned(contig ? clone.clone(MemoryFormat::Contiguous) + : cloneBatchedColumnMajor(clone)); +} + +/* + * This method is designed to be a faster alternative to + * `cloneBatchedColumnMajor` with some additional features, + * namely: + * 1. It uses `copy` instead of `clone` which could be much faster. + * 2. `nrows` parameter used to create inputs with the number of rows larger + * than the original input, which is required for some LAPACK/MAGMA methods. + * 3. `desired_batch_size` is used to create copies with the batch size + * which is either the original batch size of the input, or its larger + * broadcasted shape. + */ +inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1, + at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) { + nrows = (nrows == -1) ? src.size(-2) : nrows; + auto copy_sizes = desired_batch_sizes.has_value() + ? desired_batch_sizes.value().vec() + : IntArrayRef(src.sizes().data(), src.dim() - 2).vec(); + copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)}); + const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true); + auto copy = at::empty_strided(copy_sizes, copy_strides, src.options()); + copy.narrow(-2, 0, src.size(-2)).copy_(src); + return copy; +} + +/* + * Given batches of matrices with arbitrary batch dim, + * computes the number of batches. + */ +inline int64_t batchCount(const Tensor& batched_matrices) { + int64_t result = 1; + for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) { + result *= batched_matrices.size(i); + } + return result; +} + +// Computes the number of elements of a matrix in a batched matrix tensor +inline int64_t matrixStride(const Tensor& batched_matrices) { + return batched_matrices.size(-1) * batched_matrices.size(-2); +} + +// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig) +inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") { + TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions."); +} +inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") { + checkIsMatrix(self, f_name, arg_name); + TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2), + f_name, + ": ", arg_name, " must be batches of square matrices, " + "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices"); +} + +inline void checkInputsSolver(const Tensor& A, + const Tensor& B, + const bool left, + const char* const f_name) { + squareCheckInputs(A, f_name, "A"); + checkIsMatrix(B, f_name, "B"); + TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1), + f_name, ": Incompatible shapes of A and B for the equation ", + left ? "AX = B" : "XA = B", + " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")"); +} + +inline bool is_row_or_column_contiguous(const Tensor& t) { + // This could be made more general, similar to how it's checked in matmul, which would allow to + // ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky. 
+ // We choose to be conservative for simplicity + return t.is_contiguous() || t.transpose(-2, -1).is_contiguous(); +} + +inline TransposeType to_transpose_type(const bool contig, const bool conj) { + if (conj) { + if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); } + else { return TransposeType::ConjTranspose; } + } else { + if (contig) { return TransposeType::NoTranspose; } + else { return TransposeType::Transpose; } + } +} + + +// This function is designed to be used with linear algebra methods that minimize +// L(ax - b) = 0, where L is generally the identity map (`solve`, for example) +// or the L2 norm (`lstsq`). +// It is expected that `a` and `b` are contiguous tensors of column-major matrices +// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`), +// with the following additional properties: +// +// 1. a.dim() == b.dim() +// 2. a.shape[:-2] broadcasts over b.shape[:-2] +// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions) +// +// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method +// is to be memory efficient, which means that if there exists an index i such that +// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3, +// then instead of materializing copies of `a` in the broadcasted shape, we keep +// a buffer copy of `a` along with flags that check whether specific batch dimension +// indices for `a` were already accessed. If they were, we copy the data from the buffer +// into `a`. The number of copies does not exceed +// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1) +// and this value is attained by tensors with non-empty batch dimensions. +// +// func_t `f` is a callable that is being supplied with +// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx. +// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines, +// and a_linear_batch_idx is an index in the 3d representation which corresponds to +// the memory a_working_ptr points to, in other words: +// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}.select(0, a_linear_batch_idx).data_ptr(); +// a_linear_batch_idx is useful to store metadata related to `a`, such as, for example, +// its rank or singular values (see linalg_lstsq). 
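// (Illustrative example of the bound above.) If a.shape[:-2] == (1, 3) and
// b.shape[:-2] == (5, 3), then at most prod(max((1, 3), (5, 3)) - (1, 3) + 1)
// = prod((5, 1)) = 5 buffer copies of `a` are made, instead of materializing
// the full (5, 3) broadcast of `a`.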
+template +void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) { + IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2); + IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2); + + auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes); + auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes); + + TensorIterator iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(b_linear_batch_idx) + .add_input(a_linear_batch_idx) + .build(); + + auto m = a.size(-2); + auto n = a.size(-1); + auto a_3d = a.view({batchCount(a), m, n}); + auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)}); + + auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes); + Tensor a_buffer, a_was_accessed, a_buffer_3d; + std::function check_if_copy_needed_for_a + = [](int64_t /*a_curr_linear_batch_idx*/){}; + if (a_broadcasts_over_b) { + a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options()) + .copy_(a); + a_was_accessed = at::zeros(batchCount(a), at::kBool); + a_buffer_3d = a_buffer.view({batchCount(a), m, n}); + check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) { + auto* a_was_accessed_flag = a_was_accessed + .select(0, a_curr_linear_batch_idx) + .data_ptr(); + if (!(*a_was_accessed_flag)) { + *a_was_accessed_flag = true; + } + else { + a_3d.select(0, a_curr_linear_batch_idx) + .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx)); + } + }; + } + + auto loop = [&](char** data, const int64_t* strides, int64_t nelems) { + auto* b_batch_idx_ptr = data[0]; + auto* a_batch_idx_ptr = data[1]; + + for ([[maybe_unused]] const auto elem : c10::irange(nelems)) { + auto b_curr_linear_batch_idx = + *reinterpret_cast(b_batch_idx_ptr); + auto a_curr_linear_batch_idx = *reinterpret_cast(a_batch_idx_ptr); + + check_if_copy_needed_for_a(a_curr_linear_batch_idx); + + auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx) + .data_ptr(); + auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx) + .data_ptr(); + f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx); + + b_batch_idx_ptr += strides[0]; + a_batch_idx_ptr += strides[1]; + } + }; + iter.serial_for_each(loop, {0, batchCount(b)}); +} + +// Returns the epsilon value for floating types except half +inline double _get_epsilon(const ScalarType& sc_type) { + switch (sc_type) { + case at::ScalarType::Float: + return static_cast(std::numeric_limits::epsilon()); + case at::ScalarType::Double: + return std::numeric_limits::epsilon(); + default: + TORCH_CHECK(false, "This function doesn't handle types other than float and double"); + } +} + +// Validates input shapes and devices +// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve) +inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) { + TORCH_CHECK(self.device() == A.device(), + "Expected b and A to be on the same device, but found b on ", + self.device(), " and A on ", A.device(), " instead."); + + TORCH_CHECK(self.scalar_type() == A.scalar_type(), + "Expected b and A to have the same dtype, but found b of type ", + self.scalar_type(), " and A of type ", A.scalar_type(), " instead."); + + TORCH_CHECK(A.size(-1) == A.size(-2), + "A must be batches of square matrices, " + "but they are ", A.size(-2), " by ", A.size(-1), " matrices"); + + TORCH_CHECK(A.size(-1) == self.size(-2), + "Incompatible matrix sizes for ", name, ": each A " + "matrix is ", A.size(-1), " by ", 
A.size(-1), + " but each b matrix is ", self.size(-2), " by ", self.size(-1)); +} + +inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) { + auto dtype = t.scalar_type(); + TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)), + f_name, ": Expected a floating point or complex tensor as input. Got ", dtype); + if (!allow_low_precision_dtypes) { + TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble, + f_name, ": Low precision dtypes not supported. Got ", dtype); + } +} + + +// Checks if all the Tensors in a TensorList are of the same dimensions +inline void checkAllSameDim(TensorList tensors, int64_t dim) { + for (auto &t : tensors) { + TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead."); + } +} + +inline std::tuple, std::vector> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { + // broadcast the batch dimensions of arg1 and arg2. + IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2); + IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2); + std::vector expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes); + + std::vector arg1_expand_size({expand_batch_portion}); + arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) }); + + std::vector arg2_expand_size({expand_batch_portion}); + arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) }); + return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size)); +} + +inline std::tuple _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) { + // If there's no name we assume we don't want to check the errors + if (name != nullptr) { + linearSolveCheckInputs(arg1, arg2, name); + } + + auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2); + + auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size); + auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size); + return std::make_tuple(arg1_broadcasted, arg2_broadcasted); +} + +inline std::vector broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) { + IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims); + IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims); + auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes); + return broadcasted_batch_sizes; +} + +// Return a permutation with the given axes moved to the end. 
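// (Illustrative example.) For a 4-D tensor, _move_to_end(self, {0, 2}) permutes
// with {1, 3, 0, 2}: axes not listed keep their relative order, and the listed
// axes are appended at the end in the order given.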
+inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { + const std::vector a = axes.vec(); + const int64_t ndim = self.ndimension(); + std::vector perm; + + for (const auto i : c10::irange(ndim)) { + auto it = std::find(a.begin(), a.end(), i); + if (it == a.end()) { + perm.push_back(i); + } + } + for (auto i : a) { + perm.push_back(i); + } + + TORCH_CHECK((int64_t)perm.size() == ndim, + "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim); + + return self.permute(perm); +} + +// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced) +inline std::tuple _parse_qr_mode(std::string_view mode) { + bool compute_q; + bool reduced; + if (mode == "reduced") { + compute_q = true; + reduced = true; + } else if (mode == "complete") { + compute_q = true; + reduced = false; + } else if (mode == "r") { + compute_q = false; + reduced = true; // this is actually irrelevant in this mode + } else { + TORCH_CHECK(false, "qr received unrecognized mode '", mode, + "' but expected one of 'reduced' (default), 'r', or 'complete'"); + } + return std::make_tuple(compute_q, reduced); +} + +// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition +inline std::tuple _compute_geometry_for_Q( + const Tensor& input, + bool reduced) { + int64_t m = input.size(-2), n = input.size(-1); + int64_t n_columns_q; + + // We need to compute the required size of Q based on the `reduced` option + DimVector q_sizes(input.sizes()); + if (!reduced && m > n) { + q_sizes[input.dim() - 1] = m; + n_columns_q = m; + } else { + q_sizes[input.dim() - 1] = n; + n_columns_q = std::min(m, n); + } + auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true); + return std::make_tuple(q_sizes, q_strides, n_columns_q); +} + +inline bool svd_uses_cusolver(const Tensor& A) { + // if cusolver is available, it is used unconditionally + return A.is_cuda() + && at::globalContext().hasCuSOLVER() + && at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma; +} + + +// Function used instead of .to so that the original strides are retained +// .to doesn't retain strides and make the output tensor contiguous +inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) { + auto strided_to = at::empty_strided(original_tensor.sizes(), + original_tensor.strides(), + options); + strided_to.copy_(original_tensor); + return strided_to; +} + +// Creates a dimension permutation array that can be given to `at::permute()`, which will shift +// the two specified dimensions to the end of a tensor, without changing the order of +// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be +// placed just to the left of it. +// +// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by +// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will +// be `vec(0, 2, 1, 3)`. 
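// (Illustrative example.) create_dim_backshift_permutation(0, 2, 5) yields
// {1, 3, 4, 0, 2}; applying create_reverse_permutation (defined below) to that
// result gives {3, 0, 4, 1, 2}, the permutation that restores the original
// dimension order.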
+inline std::vector create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) { + TORCH_CHECK( + (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0), + "duplicate or invalid dimensions"); + std::vector permutation(ndim); + int64_t cur_permuted_dim = 0; + for (const auto dim_ind : c10::irange(ndim)) { + if ((dim_ind != dim0) && (dim_ind != dim1)) { + permutation[cur_permuted_dim++] = dim_ind; + } + } + permutation[cur_permuted_dim++] = dim0; + permutation[cur_permuted_dim] = dim1; + return permutation; +} + +// Creates a dimension permutation array that can be given to `at::permute()`, which +// will reverse a given permutation. +// The reverse permutation array is created by swapping the indices and their +// associated values from the given permutation array. +inline std::vector create_reverse_permutation(std::vector permutation) { + int64_t ndim = permutation.size(); + std::vector reverse_permutation(ndim); + for (const auto dim_ind : c10::irange(ndim)) { + reverse_permutation[permutation[dim_ind]] = dim_ind; + } + return reverse_permutation; +} + +// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd +// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186 +inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) { + auto mn = std::min(m, n); + auto mx = std::max(m, n); + if (jobz == 'N') { +#ifdef __APPLE__ + // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1 + return 7 * mn; +#else + // These setting is valid for on LAPACK 3.6+ + return 5 * mn; +#endif + } + if (mx > 10 * mn) { + return 5 * mn * mn + 5 * mn; + } + return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn); +} + +// This function checks whether the uplo argument input is valid +// Allowed strings are "u", "U", "l", "L" +inline void checkUplo(const std::string_view uplo) { + // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char + char uplo_uppercase = static_cast(std::toupper(static_cast(uplo[0]))); + TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'), + "Expected UPLO argument to be 'L' or 'U', but got ", uplo); +} + +inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { + TORCH_CHECK( + result.device() == input.device(), + fn_name, + ": Expected ", result_name, " and input tensors to be on the same device, but got ", + result_name, " on ", result.device(), " and input on ", input.device()); +} + +// Check the dtype of result and input tensors (for _out variants). +// Most linear algebra functions have the same dtype for input and output +// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype. +// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch +// c10::canCast is used for checking the "safe copy" dtype requirements. 
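// (Illustrative note, based on the usual c10::canCast rules.) A float input may
// be written into a double or complex result tensor, but writing a
// floating-point input into an integral result, or a complex input into a real
// result, is rejected as an unsafe copy.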
+inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { + bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type()); + TORCH_CHECK( + can_cast, + fn_name, + ": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ", + result_name, " with dtype ", result.scalar_type()); +} + +// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type) +inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") { + bool can_cast = c10::canCast(result_type, out_type); + TORCH_CHECK( + can_cast, + fn_name, + ": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ", + out_name, " with dtype ", out_type); +} + +inline void checkNotComplexTolerance(const Tensor& tol, const std::string_view f_name, const std::string_view tol_name) { + TORCH_CHECK(!at::isComplexType(tol.scalar_type()), + f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type()); +} + +/* + Two types of 'other' tensors are supported when solving + a system of linear equations matmul(input, x) = other: + * 1-dimensional (1D) tensor or batch of 1D tensors (vector case) + * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case). + The original torch.solve supported only the matrix case, while NumPy works for both cases. + For the batched input we need to be able to distinguish them. + Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m). + This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389 +*/ +inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) { + auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1] + bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape)); + return vector_case; +} + +/* + Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor. +*/ +inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) { + TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous(); +} + +class BroadcastLinearIndices { + private: + Tensor linear_indices_; + bool is_broadcasting_; + + public: + BroadcastLinearIndices( + int64_t numel, + IntArrayRef original_shape, + IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) { + // The assumption is that the broadcast_shape is a materialized broadcast + // shape of the original_shape. We need to compute the linear indices + // compatible with the original_shape to access the elements in the original + // tensor corresponding to the broadcast tensor. + if (is_broadcasting_) { + linear_indices_ = + get_linear_indices(numel, original_shape, broadcast_shape); + } + } + int64_t operator()(int64_t broadcast_linear_index) { + return is_broadcasting_ + ? 
linear_indices_.data_ptr()[broadcast_linear_index] + : broadcast_linear_index; + } +}; + +inline bool is_blas_compatible_column_major_order(const Tensor& input) { + IntArrayRef input_strides = input.strides(); + IntArrayRef input_sizes = input.sizes(); + auto ndim = input.dim(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2); + if (ndim > 3) { + return input.transpose(-2, -1).is_contiguous(); + } + auto leading_dimension = input_strides[ndim - 1]; + auto rows = input_sizes[ndim - 2]; + bool batch_stride_compatible = true; + if (ndim == 3) { + auto cols = input_sizes[ndim - 1]; + batch_stride_compatible = + input_strides[ndim - 3] >= leading_dimension * cols; + } + return (input_strides[ndim - 2] == 1) && + (leading_dimension >= std::max(1, rows)) && + batch_stride_compatible; +} + +inline bool is_blas_compatible_row_major_order(const Tensor& input) { + IntArrayRef input_strides = input.strides(); + IntArrayRef input_sizes = input.sizes(); + auto ndim = input.dim(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2); + if (ndim > 3) { + return input.is_contiguous(); + } + auto leading_dimension = input_strides[ndim - 2]; + auto cols = input_sizes[ndim - 1]; + bool batch_stride_compatible = true; + if (ndim == 3) { + auto rows = input_sizes[ndim - 2]; + batch_stride_compatible = + input_strides[ndim - 3] >= leading_dimension * rows; + } + return (input_strides[ndim - 1] == 1) && + (leading_dimension >= std::max(1, cols)) && + batch_stride_compatible; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/LossMulti.h b/phivenv/Lib/site-packages/torch/include/ATen/native/LossMulti.h new file mode 100644 index 0000000000000000000000000000000000000000..6484ae7aeb824695af7559e3abb07b1799a82bb8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/LossMulti.h @@ -0,0 +1,69 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + inline void multilabel_margin_loss_shape_check( + int64_t& nframe, + int64_t& dim, + const int64_t& ndims, + const Tensor& input, + const Tensor& target) { + TORCH_CHECK( + (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input.sizes()); + + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 1 : input.size(0); + TORCH_CHECK( + target.dim() <= 1 && target.numel() == dim, + "inconsistent target size: ", target.sizes(), " for input of size: ", + input.sizes()); + } else { + nframe = input.size(0); + dim = input.size(1); + TORCH_CHECK( + target.dim() == 2 && target.size(0) == nframe && + target.size(1) == dim, + "inconsistent target size: ", target.sizes(), " for input of size: ", + input.sizes()); + } + } + + inline void multi_margin_loss_shape_check( + int64_t& nframe, + int64_t& dim, + const int64_t& ndims, + const Tensor& input, + const Tensor& target, + const std::optional& weight) { + TORCH_CHECK( + (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input.sizes()); + + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 
1 : input.size(0); + } else { + nframe = input.size(0); + dim = input.size(1); + } + + TORCH_CHECK( + target.dim() <= 1 && target.numel() == nframe, + "inconsistent target size, expected ", nframe, " but got ", + target.sizes()); + if (weight && weight->defined()) { + TORCH_CHECK( + weight->dim() <= 1 && weight->numel() == dim, + "inconsistent weight size, expected ", dim, " but got ", + weight->sizes()); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Math.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Math.h new file mode 100644 index 0000000000000000000000000000000000000000..2b7714c580e04bcf1b8d2002323de4612b11390d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Math.h @@ -0,0 +1,3927 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. +Below is the copyright. +Output was modified to be inf or -inf when input is 1 or -1. */ + + +/* + Copyright (c) 2014 Indiana University + All rights reserved. + + Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., + Indiana University, Bloomington, IN + + This software is licensed under the New BSD license: + + Redistribution and use in source and binary forms, + with or without modification, are permitted provided + that the following conditions are met: + + Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. + + Redistributions in binary form must reproduce the + above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or + other materials provided with the distribution. + + Neither the name of Indiana University nor + the names of its contributors may be used to endorse + or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND + CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF + USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER + IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. +*/ + +namespace { +/* + * This function is derived from the implementation of the i0e function in the + * Cephes Math Library. See note [3-Clause BSD License for the Cephes Math + * Library]. + * + * Computes an approximation of the exponentially scaled zeroth order modified + * Bessel function of the first kind. The approximation is actually two + * (sub)approximations, both using a Chebyshev polynomial expansion. One + * approximates the function over [0, 8], and the other over (8, infinity). 
This + * function takes the absolute value of all inputs to convert them into the + * domain of the approximation. + */ +jiterator_also_stringify_as(jiterator_code( + template + JITERATOR_HOST_DEVICE T chbevl(T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + JITERATOR_HOST_DEVICE T calc_i0e(T _x) { + T x = std::fabs(_x); + + if (x <= T{8.0}) { + static const T coefficients[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + T y = (x / T{2.0}) - T{2.0}; + return chbevl(y, coefficients, int{30}); + } + + // x > 8 + static const T coefficients[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return chbevl(T{32.0} / x - T{2.0}, coefficients, int{25}) / std::sqrt(x); + }), + i0e_string) // i0e_string +} + +#define CENTRAL_RANGE 0.7 + +template +inline typename std::enable_if_t, T> +calc_erfinv(T y) { +/* Function to calculate inverse error function. Rational approximation +is used to generate an initial approximation, which is then improved to +full accuracy by two steps of Newton's method. Code is a direct +translation of the erfinv m file in matlab version 2.0. +Author: Gary L. 
Pavlis, Indiana University +Date: February 1996 +*/ + T x, z, num, dem; /*working variables */ + /* coefficients in rational expansion */ + T a[4] = { T(0.886226899), T(-1.645349621), T(0.914624893), T(-0.140543331) }; + T b[4] = { T(-2.118377725), T(1.442710462), T(-0.329097515), T(0.012229801) }; + T c[4] = { T(-1.970840454), T(-1.624906493), T(3.429567803), T(1.641345311) }; + T d[2] = { T(3.543889200), T(1.637067800) }; + T y_abs = std::abs(y); + if(y_abs > 1.0) return std::numeric_limits::quiet_NaN(); +#ifdef _WIN32 + // error C2039: '_copysign': is not a member of 'std' + if(y_abs == 1.0) return copysign(std::numeric_limits::infinity(), y); +#else + if(y_abs == 1.0) return std::copysign(std::numeric_limits::infinity(), y); +#endif + if(y_abs <= static_cast(CENTRAL_RANGE)) { + z = y * y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + static_cast(1.0)); + x = y * num / dem; + } + else{ + z = std::sqrt(-std::log((static_cast(1.0)-y_abs)/static_cast(2.0))); + num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; + dem = (d[1]*z + d[0])*z + static_cast(1.0); +#ifdef _WIN32 + // error C2039: '_copysign': is not a member of 'std' + x = copysign(num, y) / dem; +#else + x = std::copysign(num, y) / dem; +#endif + } + /* Two steps of Newton-Raphson correction */ + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); + + return(x); +} + +#undef CENTRAL_RANGE + +/* + * Note [3-Clause BSD License for the Cephes Math Library] + * Code derived from implementations in the Cephes Math Library should mention its derivation and reference + * this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note + * [3-Clause BSD License for the Cephes Math Library]. The license is: + * Copyright (c) 2018, Steven Moshier + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ + +/* + * This function is derived from the implementation of the zeta function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +template +C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ { + using acc_t = at::acc_type; + const acc_t MACHEP = acc_t{1.11022302462515654042E-16}; + constexpr acc_t zero = acc_t{0.0}; + constexpr acc_t half = acc_t{0.5}; + constexpr acc_t one = acc_t{1.0}; + static const acc_t A[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, /*1.307674368e12/691*/ + 7.47242496e10, + -2.950130727918164224e12, /*1.067062284288e16/3617*/ + 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ + -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ + 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ + -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ + }; + + acc_t a, b, k, s, t, w; + if (x == one) { + return std::numeric_limits::infinity(); + } + + if (x < one) { + return std::numeric_limits::quiet_NaN(); + } + + if (q <= zero) { + if (q == std::floor(q)) { + return std::numeric_limits::infinity(); + } + if (x != std::floor(x)) { + return std::numeric_limits::quiet_NaN(); + } + } + + s = std::pow(q, -x); + a = q; + int i = 0; + b = zero; + while ((i < 9) || (a <= acc_t{9.0})) { + i += 1; + a += one; + b = ::pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return static_cast(s); + } + }; + + w = a; + s += b * w / (x - one); + s -= half * b; + a = one; + k = zero; + for (i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / A[i]; + s = s + t; + t = ::fabs(t / s); + if (t < MACHEP) { + return static_cast(s); + } + k += one; + a *= x + k; + b /= w; + k += one; + } + return static_cast(s); +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates polynomial of degree N: + * + * 2 N + * y = C + C x + C x +...+ C x + * 0 1 2 N + * + * Coefficients are stored in reverse order: + * + * coef[0] = C , ..., coef[N] = C . + * N 0 + */ +template +C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) { + T result = 0; + for (size_t i = 0; i <= len; i++) { + result = result * x + A[i]; + } + return result; +} + +inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { + double sign = +1; + double result = 0; + if (x < 0.5) { + sign = -1; + const double sin_pi_x = sin(c10::pi * x); + result -= (c10::pi * c10::pi) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const double ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1./6 - ixx * (1./30 - ixx * (1./42)))) / x; + return sign * result; +} + +inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { + float sign = +1; + float result = 0; + if (x < 0.5f) { + sign = -1; + const float sin_pi_x = sinf(c10::pi * x); + result -= (c10::pi * c10::pi) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const float ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x; + return sign * result; +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. 
+ */ +inline double calc_digamma(double x) { + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma + static double PSI_10 = 2.25175258906672110764; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == trunc(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = std::modf(x, &q); + return calc_digamma(1 - x) - c10::pi / tan(c10::pi * r); + } + + // Push x to be >= 10 + double result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const double A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + double y = 0; + if (x < 1.0e17) { + double z = 1.0 / (x * x); + y = z * polevl(z, A, 6); + } + return result + log(x) - (0.5 / x) - y; +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +inline float calc_digamma(float x) { + // See [C++ Standard Reference: Gamma Function] + static float PSI_10 = 2.25175258906672110764f; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == truncf(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). 
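+ // Note: q and r are declared double here, so the fractional part and
+ // pi / tan(pi * r) below are evaluated in double precision and only then
+ // narrowed back to float (see the cast to pi_over_tan_pi_x).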
+ double q, r; + r = std::modf(x, &q); + float pi_over_tan_pi_x = (float)(c10::pi / tan(c10::pi * r)); + return calc_digamma(1 - x) - pi_over_tan_pi_x; + } + + // Push x to be >= 10 + float result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const float A[] = { + 8.33333333333333333333E-2f, + -2.10927960927960927961E-2f, + 7.57575757575757575758E-3f, + -4.16666666666666666667E-3f, + 3.96825396825396825397E-3f, + -8.33333333333333333333E-3f, + 8.33333333333333333333E-2f, + }; + + float y = 0; + if (x < 1.0e17f) { + float z = 1 / (x * x); + y = z * polevl(z, A, 6); + } + return result + logf(x) - (0.5f / x) - y; +} + +inline c10::BFloat16 calc_digamma(c10::BFloat16 a) { + return calc_digamma(static_cast(a)); +} + +inline c10::Half calc_digamma(c10::Half a) { + return calc_digamma(static_cast(a)); +} + +template +inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { + // already blocked if n <= 1 + const auto one = scalar_t{1}; + return ((n % 2) ? one : -one) * + std::exp(std::lgamma(static_cast(n) + one)) * + zeta(static_cast(n + 1), x); +} + +// regularized lower incomplete gamma +// the regularized lower, upper incomplete gamma, as well as their +// helper functions follow SciPy's implementation + +/* References + * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov + * [igam2] Maddock et al., "Incomplete Gamma Functions", + * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html + */ + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +template +scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + scalar_t y, num_ans, denom_ans; + scalar_t absx = std::fabs(x); + const scalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return std::pow(x, i) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +// SciPy's lanczos implementation is taken from Boost +/* (C) Copyright John Maddock 2006. + * Use, modification and distribution are subject to the + * Boost Software License, Version 1.0. See + * https://www.boost.org/LICENSE_1_0.txt or see NOTICE. 
+ */ +template +static scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + static const scalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const scalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0. + }; + return ratevl(x, lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + scalar_t ax, fac, res, num, numfac; + static scalar_t MAXLOG = std::is_same_v ? + 7.09782712893383996843E2 : 88.72283905206835; + static scalar_t EXP1 = 2.718281828459045; + static scalar_t lanczos_g = 6.024680040776729583740234375; + + if (std::fabs(a - x) > 0.4 * std::fabs(a)) { + ax = a * std::log(x) - x - std::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return std::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= std::exp(a - x) * std::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + static scalar_t MACHEP = std::is_same_v ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + scalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + scalar_t fac = 1; + scalar_t sum = 0; + scalar_t term, logx; + static scalar_t MAXITER = 2000; + static scalar_t MACHEP = std::is_same_v ? 
+ 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + static const scalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, + 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, + 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, + -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, + -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, + -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, + 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, + -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, + -2.1492464706134829e-12, 
4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, + 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, + 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, + -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, + 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, + -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, + -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, + 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, + -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 
7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, 
-2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + 
-3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static scalar_t MACHEP = std::is_same_v ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + scalar_t lambda = x / a; + scalar_t sigma = (x - a) / a; + scalar_t eta, res, ck, ckterm, term, absterm; + scalar_t absoldterm = INFINITY; + scalar_t etapow[25] = {1}; + scalar_t sum = 0; + scalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * c10::pi * a); + + return res; +} + +template +static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + scalar_t ans, ax, c, yc, r, t, y, z; + scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static scalar_t MACHEP = std::is_same_v ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static scalar_t BIG = std::is_same_v ? + 4.503599627370496e15 : 16777216.; + static scalar_t BIGINV = std::is_same_v ? 
+ 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + scalar_t absxma_a; + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + // 
note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +template <> +[[maybe_unused]] inline c10::BFloat16 calc_igamma( + c10::BFloat16 a, + c10::BFloat16 x) { + return calc_igamma(float(a), float(x)); +} + +template <> +[[maybe_unused]] inline c10::Half calc_igamma( + c10::Half a, + c10::Half x) { + return calc_igamma(float(a), float(x)); +} + +template <> +[[maybe_unused]] inline c10::BFloat16 calc_igammac( + c10::BFloat16 a, + c10::BFloat16 x) { + return calc_igammac(float(a), float(x)); +} + +template <> +[[maybe_unused]] inline c10::Half calc_igammac( + c10::Half a, + c10::Half x) { + return calc_igammac(float(a), float(x)); +} + +inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } + +template +inline T abs_impl(T v) { + return std::abs(v); +} + +template <> +[[maybe_unused]] inline uint8_t abs_impl(uint8_t v) { + return v; +} + +template +inline typename std::enable_if_t, T> +calc_gcd(T a, T b) { + a = abs_impl(a); + b = abs_impl(b); + while (a != 0) { + T c = a; + a = b % a; + b = c; + } + return b; +} + +template +C10_HOST_DEVICE T exp2_impl(T x) { + return std::exp2(x); +} + +template +C10_HOST_DEVICE c10::complex exp2_impl(c10::complex x) { + // There is no std::exp2 overload for complex, so instead + // use the identity 2^x = e^(ln(2) * x) + constexpr auto ln2 = c10::ln_2; + return std::exp(ln2 * x); +} + +/* + * This function is derived from the implementation of the chbevl function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates the series + * + * len-1 + * - ' + * y = > array[i] T (x/2) + * - i + * i=0 + * + * of Chebyshev polynomials Ti at argument x/2. + * + * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of + * coefficients, not the order. + * + * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before + * entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined. + * + * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation + * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. 
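+ *
+ * For example, for coefficients fitted on (a, b) = (0, 8) the transformation is
+ * x -> 2(2x - 8 - 0)/(8 - 0) = x/2 - 2, which is exactly the argument the Bessel
+ * helpers below pass, chbevl(x / 2 - 2, A, len). The loop body
+ * b0 = x * b1 - b2 + array[i] is Clenshaw's recurrence for a series in T_i(x/2),
+ * since twice the evaluation point x/2 is the argument x handed to this routine.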
+ */ +template +inline typename std::enable_if_t, T> +chbevl(const T x, const T array[], size_t len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = static_cast(0.0); + + for (size_t i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return (static_cast(0.5) * (b0 - b2)); +} + +/* + * This function is derived from the implementation of the i0 function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the zeroth order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline std::tuple chebyshev_coefficients_i0e_A() { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + static const T coeff[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + return std::make_tuple(coeff, 30); +} + +template +inline std::tuple chebyshev_coefficients_i0e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return std::make_tuple(coeff, 25); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. 
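+ *
+ * In other words, the series below approximates exp(-x) * I1(x) / x, which stays
+ * finite (it tends to 1/2) as x -> 0; calc_i1 and calc_i1e further down multiply
+ * the Chebyshev sum back by x (and by exp(x) for the unscaled variant).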
+ */ + static const T coeff[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + return std::make_tuple(coeff, 29); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. + */ + static const T coeff[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + return std::make_tuple(coeff, 17); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + return std::make_tuple(coeff, 25); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). 
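+ *
+ * This single-precision table keeps only the last 7 of the 25 double-precision
+ * coefficients above; the dropped leading terms are several orders of magnitude
+ * below float round-off, so the truncated fit should lose no float accuracy.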
+ */ + static const T coeff[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + return std::make_tuple(coeff, 7); +} + +template +inline typename std::enable_if_t, T> +calc_i0(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto [A, len] = chebyshev_coefficients_i0e_A(); + T y = (x / T{2.0}) - T{2.0}; + return static_cast(std::exp(x) * chbevl(y, A, len)); + } + auto [B, len] = chebyshev_coefficients_i0e_B(); + return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); +} + +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } +inline c10::Half calc_i0(c10::Half a) { return calc_i0(static_cast(a)); } + +/* + * This function is derived from the implementation of the i1 function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the first order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline typename std::enable_if_t, T> +calc_i1(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + T y = (x / T{2.0}) - T{2.0}; + const T out = std::exp(x) * x * chbevl(y, A, len); + return (_x < T{0.0}) ? -out : out; + } + auto [B, len] = chebyshev_coefficients_i1e_B(); + const T out = (std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len)) / std::sqrt(x); + return (_x < T{0.0}) ? -out : out; +} + +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i1(c10::BFloat16 a) { return calc_i1(static_cast(a)); } +inline c10::Half calc_i1(c10::Half a) { return calc_i1(static_cast(a)); } + + +/* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the exponentially scaled first order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline typename std::enable_if_t, T> +calc_i1e(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + T y = (x / T{2.0}) - T{2.0}; + const T out = chbevl(y, A, len) * x; + return (_x < T{0.0}) ? -out : out; + } + auto [B, len] = chebyshev_coefficients_i1e_B(); + const auto out = chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); + return (_x < T{0.0}) ? -out : out; +} + +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i1e(c10::BFloat16 a) { return calc_i1e(static_cast(a)); } +inline c10::Half calc_i1e(c10::Half a) { return calc_i1e(static_cast(a)); } + + +/* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. 
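+ * (In Cephes terms this appears to be the ndtri routine, the inverse of the
+ * standard normal CDF; equivalently, ndtri(y) = sqrt(2) * erfinv(2*y - 1) for
+ * y in (0, 1).)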
+ * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes the argument, x, for which the area under the Gaussian probability density function + * (integrated from minus infinity to x) is equal to y. + */ +template +inline C10_HOST_DEVICE T calc_ndtri(T y0) { + + /* sqrt(2pi) */ + constexpr T s2pi = 2.50662827463100050242E0; + constexpr T one = 1; + constexpr T zero = 0; + + /* approximation for 0 <= |y - 0.5| <= 3/8 */ + static const T P0[5] = { + -5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0, + }; + + static const T Q0[9] = { + 1.00000000000000000000E0, + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0, + }; + + /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8 + * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14. + */ + static const T P1[9] = { + 4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4, + }; + + static const T Q1[9] = { + 1.00000000000000000000E0, + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4, + }; + + /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64 + * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890. + */ + + static const T P2[9] = { + 3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9, + }; + + static const T Q2[9] = { + 1.00000000000000000000E0, + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9, + }; + + if (y0 == zero) { + return -std::numeric_limits::infinity(); + } + if (y0 == one) { + return std::numeric_limits::infinity(); + } + if (y0 < zero || y0 > one) { + return std::numeric_limits::quiet_NaN(); + } + bool code = true; + T y = y0; + if (y > one - T{0.13533528323661269189}) { /* 0.135... 
= exp(-2) */ + y = one - y; + code = false; + } + + if (y > T{0.13533528323661269189}) { + y = y - T{0.5}; + const T y2 = y * y; + T x = y + y * (y2 * polevl(y2, P0, 4) / polevl(y2, Q0, 8)); + return (x * s2pi); + } + + T x = ::sqrt(T{-2.0} * ::log(y)); + const T x0 = x - ::log(x) / x; + + const T z = one / x; + T x1; + if (x < T{8.0}) /* y > exp(-32) = 1.2664165549e-14 */ + { + x1 = z * polevl(z, P1, 8) / polevl(z, Q1, 8); + } else { + x1 = z * polevl(z, P2, 8) / polevl(z, Q2, 8); + } + x = x0 - x1; + if (code) { + x = -x; + } + return x; +} + +/* The next function is taken from http://ab-initio.mit.edu/faddeeva */ + +/* Copyright (c) 2012 Massachusetts Institute of Technology + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +/* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by + Steven G. Johnson, October 2012. + + This function combines a few different ideas. + + First, for x > 50, it uses a continued-fraction expansion (same as + for the Faddeeva function, but with algebraic simplifications for z=i*x). + + Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations, + but with two twists: + + a) It maps x to y = 4 / (4+x) in [0,1]. This simple transformation, + inspired by a similar transformation in the octave-forge/specfun + erfcx by Soren Hauberg, results in much faster Chebyshev convergence + than other simple transformations I have examined. + + b) Instead of using a single Chebyshev polynomial for the entire + [0,1] y interval, we break the interval up into 100 equal + subintervals, with a switch/lookup table, and use much lower + degree Chebyshev polynomials in each subinterval. This greatly + improves performance in my tests. + + For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), + with the usual checks for overflow etcetera. + + Performance-wise, it seems to be substantially faster than either + the SLATEC DERFC function [or an erfcx function derived therefrom] + or Cody's CALERF function (from netlib.org/specfun), while + retaining near machine precision in accuracy. */ + +/* Given y100=100*y, where y = 4/(4+x) for x >= 0, compute erfc(x). + + Uses a look-up table of 100 different Chebyshev polynomials + for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated + with the help of Maple and a little shell script. This allows + the Chebyshev polynomials to be of significantly lower degree (about 1/4) + compared to fitting the whole [0,1] interval with a single polynomial. 
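+
+   As a concrete example of the lookup: for x = 6, y = 4/(4 + 6) = 0.4, so
+   y100 = 40 and case 40 below is selected; its local variable t = 2*y100 - 81
+   maps the subinterval y100 in [40, 41] onto [-1, 1] before the low-degree
+   Chebyshev polynomial in t is evaluated.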
*/ + + +template +C10_HOST_DEVICE inline typename std::enable_if_t, T> +erfcx_y100(T y100) +{ + switch (static_cast(y100)) { +case 0: { +T t = 2*y100 - 1; +return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t; +} +case 1: { +T t = 2*y100 - 3; +return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 2: { +T t = 2*y100 - 5; +return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 3: { +T t = 2*y100 - 7; +return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t; +} +case 4: { +T t = 2*y100 - 9; +return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t; +} +case 5: { +T t = 2*y100 - 11; +return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t; +} +case 6: { +T t = 2*y100 - 13; +return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 7: { +T t = 2*y100 - 15; +return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t; +} +case 8: { +T t = 2*y100 - 17; +return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 9: { +T t = 2*y100 - 19; +return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 10: { +T t = 2*y100 - 21; +return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t; +} +case 11: { +T t = 2*y100 - 23; +return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 12: { +T t = 2*y100 - 25; +return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + 
(0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 13: { +T t = 2*y100 - 27; +return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 14: { +T t = 2*y100 - 29; +return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 15: { +T t = 2*y100 - 31; +return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 16: { +T t = 2*y100 - 33; +return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 17: { +T t = 2*y100 - 35; +return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t; +} +case 18: { +T t = 2*y100 - 37; +return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t; +} +case 19: { +T t = 2*y100 - 39; +return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 20: { +T t = 2*y100 - 41; +return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 21: { +T t = 2*y100 - 43; +return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 22: { +T t = 2*y100 - 45; +return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 23: { +T t = 2*y100 - 47; +return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t; +} +case 24: { +T t = 2*y100 - 49; +return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 25: { +T t = 2*y100 - 51; +return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + 
(0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 26: { +T t = 2*y100 - 53; +return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 27: { +T t = 2*y100 - 55; +return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 28: { +T t = 2*y100 - 57; +return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 29: { +T t = 2*y100 - 59; +return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 30: { +T t = 2*y100 - 61; +return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 31: { +T t = 2*y100 - 63; +return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 32: { +T t = 2*y100 - 65; +return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + (0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 33: { +T t = 2*y100 - 67; +return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 34: { +T t = 2*y100 - 69; +return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t; +} +case 35: { +T t = 2*y100 - 71; +return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 36: { +T t = 2*y100 - 73; +return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 37: { +T t = 2*y100 - 75; +return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 38: { +T t = 2*y100 - 77; +return 0.87254084284461718231e-1 + 
(0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 39: { +T t = 2*y100 - 79; +return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t; +} +case 40: { +T t = 2*y100 - 81; +return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 41: { +T t = 2*y100 - 83; +return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 42: { +T t = 2*y100 - 85; +return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 43: { +T t = 2*y100 - 87; +return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 44: { +T t = 2*y100 - 89; +return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t; +} +case 45: { +T t = 2*y100 - 91; +return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + (0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 46: { +T t = 2*y100 - 93; +return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 47: { +T t = 2*y100 - 95; +return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 48: { +T t = 2*y100 - 97; +return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 49: { +T t = 2*y100 - 99; +return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 50: { +T t = 2*y100 - 101; +return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t; +} +case 51: { +T t = 2*y100 - 103; +return 0.14497287157673800690e0 + 
(0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 52: { +T t = 2*y100 - 105; +return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t; +} +case 53: { +T t = 2*y100 - 107; +return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t; +} +case 54: { +T t = 2*y100 - 109; +return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 55: { +T t = 2*y100 - 111; +return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 56: { +T t = 2*y100 - 113; +return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 57: { +T t = 2*y100 - 115; +return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 58: { +T t = 2*y100 - 117; +return 0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 59: { +T t = 2*y100 - 119; +return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 60: { +T t = 2*y100 - 121; +return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 61: { +T t = 2*y100 - 123; +return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 62: { +T t = 2*y100 - 125; +return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 63: { +T t = 2*y100 - 127; +return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 64: { +T t = 2*y100 - 129; +return 
0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 65: { +T t = 2*y100 - 131; +return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t; +} +case 66: { +T t = 2*y100 - 133; +return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 67: { +T t = 2*y100 - 135; +return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t; +} +case 68: { +T t = 2*y100 - 137; +return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t; +} +case 69: { +T t = 2*y100 - 139; +return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 70: { +T t = 2*y100 - 141; +return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 71: { +T t = 2*y100 - 143; +return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 72: { +T t = 2*y100 - 145; +return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 73: { +T t = 2*y100 - 147; +return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t; +} +case 74: { +T t = 2*y100 - 149; +return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 75: { +T t = 2*y100 - 151; +return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t; +} +case 76: { +T t = 2*y100 - 153; +return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 77: { +T t = 2*y100 - 155; 
+return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 78: { +T t = 2*y100 - 157; +return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 79: { +T t = 2*y100 - 159; +return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 80: { +T t = 2*y100 - 161; +return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 81: { +T t = 2*y100 - 163; +return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 82: { +T t = 2*y100 - 165; +return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 83: { +T t = 2*y100 - 167; +return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + (0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 84: { +T t = 2*y100 - 169; +return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 85: { +T t = 2*y100 - 171; +return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 86: { +T t = 2*y100 - 173; +return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t; +} +case 87: { +T t = 2*y100 - 175; +return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 88: { +T t = 2*y100 - 177; +return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t; +} +case 89: { +T t = 2*y100 - 179; +return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 90: { +T t = 2*y100 
- 181; +return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 91: { +T t = 2*y100 - 183; +return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 92: { +T t = 2*y100 - 185; +return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t; +} +case 93: { +T t = 2*y100 - 187; +return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 94: { +T t = 2*y100 - 189; +return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 95: { +T t = 2*y100 - 191; +return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 96: { +T t = 2*y100 - 193; +return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 + (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 97: { +T t = 2*y100 - 195; +return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t; +} +case 98: { +T t = 2*y100 - 197; +return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 99: { +T t = 2*y100 - 199; +return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t; +} + } + // we only get here if y = 1, i.e. |x| < 4*eps, in which case + // erfcx is within 1e-15 of 1.. 
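+  // (Recall that erfcx(x) = exp(x*x) * erfc(x) and erfc(0) = 1, so erfcx(0) = 1 exactly; for |x| < 4*eps the deviation from 1 stays below the 1e-15 level quoted above.)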
+ return 1.0; +} + +template +C10_HOST_DEVICE inline typename std::enable_if_t, T> +calc_erfcx(T x) +{ + if (at::_isnan(x)) { + return x; + } + + if (x >= 0) { + if (x > 50) { // continued-fraction expansion is faster + const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi) + if (x > 5e7) { // 1-term expansion, important to avoid overflow + return ispi / x; + } + /* 5-term expansion (rely on compiler for CSE), simplified from: + ispi / (x+0.5/(x+1/(x+1.5/(x+2/x)))) */ + return ispi*((x*x) * (x*x+4.5) + 2) / (x * ((x*x) * (x*x+5) + 3.75)); + } + return erfcx_y100(400/(4+x)); + } + else { + if (x < -26.7) { + return std::numeric_limits::infinity(); + } + else if (x < -6.1) { + return 2*exp(x*x); + } + else { + return 2*exp(x*x) - erfcx_y100(400/(4-x)); + } + } +} + +/* + * Logarithm of Gaussian cumulative distribution function. + + * This implementation of log_ndtr and its helper functions + * follow SciPy's implementation + * See NOTICE for the licenses. + */ +template +inline C10_HOST_DEVICE T calc_log_ndtr(T x) { + T t = x * c10::frac_sqrt_2; + if (x < T{-1.0}) { + return std::log(calc_erfcx(-t) / 2) - t * t; + } else { + return std::log1p(-std::erfc(t) / 2); + } +} + +template +inline C10_HOST_DEVICE T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + +7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + +8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * std::sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for 
(uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd = 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * std::sqrt(-x) / T(3.0) + T(0.25) * c10::pi; + + return T(5.64189583547756286948e-01) / std::sqrt(std::sqrt(-x)) * (std::sin(t) * (T(1.0) + z * z * afn / afd) - std::cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * std::sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * std::sqrt(std::sqrt(x)) * std::exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > std::numeric_limits::epsilon()) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = std::abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; +} // T airy_ai(T x) + +template +inline C10_HOST_DEVICE T bessel_j0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T RP[] = { + -4.79443220978201773821e+09, + +1.95617491946556577543e+12, + -2.49248344360967716204e+14, + +9.70862251047306323952e+15, + }; + + static const T RQ[] = { + +4.99563147152651017219e+02, + +1.73785401676374683123e+05, + +4.84409658339962045305e+07, + +1.11855537045356834862e+10, + +2.11277520115489217587e+12, + +3.10518229857422583814e+14, + +3.18121955943204943306e+16, + +1.71086294081043136091e+18, + }; + + if (x < T(0)) { + x = -x; + } + + if (x <= T(5.0)) { + if (x < T(0.00001)) { + return T(1.0) - x * x / T(4.0); + } + + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq; + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for 
(uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * std::cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * std::sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_j0_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_j1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T RP[] = { + -8.99971225705559398224e+08, + +4.52228297998194034323e+11, + -7.27494245221818276015e+13, + +3.68295732863852883286e+15, + }; + + static const T RQ[] = { + +6.20836478118054335476e+02, + +2.56987256757748830383e+05, + +8.35146791431949253037e+07, + +2.21511595479792499675e+10, + +4.74914122079991414898e+12, + +7.84369607876235854894e+14, + +8.95222336184627338078e+16, + +5.32278620332680085395e+18, + }; + + if (x < T(0.0)) { + return -bessel_j1_forward(-x); + } + + if (x <= T(5.0)) { + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * std::cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * std::sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_j1_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_y0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + 
+9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T YP[] = { + +1.55924367855235737965e+04, + -1.46639295903971606143e+07, + +5.43526477051876500413e+09, + -9.82136065717911466409e+11, + +8.75906394395366999549e+13, + -3.46628303384729719441e+15, + +4.42733268572569800351e+16, + -1.84950800436986690637e+16, + }; + + static const T YQ[] = { + +1.04128353664259848412e+03, + +6.26107330137134956842e+05, + +2.68919633393814121987e+08, + +8.64002487103935000337e+10, + +2.02979612750105546709e+13, + +3.17157752842975028269e+15, + +2.50596256172653059228e+17, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return yp / yq + (T(0.636619772367581343075535053490057448) * std::log(x) * bessel_j0_forward(x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * std::sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * std::cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_y0_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_y1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + 
+7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T YP[] = { + +1.26320474790178026440e+09, + -6.47355876379160291031e+11, + +1.14509511541823727583e+14, + -8.12770255501325109621e+15, + +2.02439475713594898196e+17, + -7.78877196265950026825e+17, + }; + + static const T YQ[] = { + +5.94301592346128195359e+02, + +2.35564092943068577943e+05, + +7.34811944459721705660e+07, + +1.87601316108706159478e+10, + +3.88231277496238566008e+12, + +6.20557727146953693363e+14, + +6.87141087355300489866e+16, + +3.97270608116560655612e+18, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -std::numeric_limits::infinity(); + } + + if (x <= T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 5; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * std::log(x) - T(1.0) / x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * std::sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * std::cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_y1_forward(T x) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (std::abs(x) < T(1.0))) { + return std::cos(n * std::acos(x)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_t_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) { + return chebyshev_polynomial_t_forward(x, static_cast(n)); +} // chebyshev_polynomial_t_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::sin(std::acos(x)) != T(0.0)) { + return std::sin((n + 1) * std::acos(x)) / std::sin(std::acos(x)); + } + + return (n + 1) * std::cos((n + 1) * std::acos(x)) / x; + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_u_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) { + return chebyshev_polynomial_u_forward(x, static_cast(n)); +} // chebyshev_polynomial_u_forward(T x, T n) + +template +inline 
C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0)) { + return T(1.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::sin(std::acos(x) / T(2.0)) != T(1.0)) { + return std::cos((n + T(0.5)) * std::acos(x)) / std::cos(std::acos(x) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_v_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) { + return chebyshev_polynomial_v_forward(x, static_cast(n)); +} // chebyshev_polynomial_v_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::cos(std::acos(x) / T(2.0)) != T(1.0)) { + return std::sin((n + T(0.5)) * std::acos(x)) / std::sin(std::acos(x) / T(2.0)); + } + + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x + T(1.0); + } + + T p = T(1.0); + T q = x + x + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_w_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) { + return chebyshev_polynomial_w_forward(x, static_cast(n)); +} // chebyshev_polynomial_w_forward(T x, T n) + +template +constexpr auto getHermitianLimit() { + if constexpr (std::is_same_v) { + return 128; + } else if constexpr (std::is_same_v) { + return 512; + } else { + return 1024; + } +} + +template +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + if (n > getHermitianLimit()) { + return std::numeric_limits::quiet_NaN(); + } + + T p = T(1.0); + T q = x + x; + T r = T(0.0); + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_h_forward(T x, int64_t n) + +template, int> = 0> +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, static_cast(n)); +} // hermite_polynomial_h_forward(T x, T n) + +template, int> = 0> +__ubsan_ignore_float_cast_overflow__ inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, (!std::isinf(n) && !std::isnan(n)) ? 
static_cast(n) : static_cast(-1)); +} // hermite_polynomial_h_forward(T x, T n) + +template +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + if (n > getHermitianLimit()) { + return std::numeric_limits::quiet_NaN(); + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = x * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_he_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) { + return hermite_polynomial_he_forward(x, static_cast(n)); +} // hermite_polynomial_he_forward(T x, T n) + +template +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(0.0)) { + return T(1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return T(1.0) - x; + } + + T p = T(1.0); + T q = T(1.0) - x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; +} // laguerre_polynomial_l_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) { + return laguerre_polynomial_l_forward(x, static_cast(n)); +} // laguerre_polynomial_l_forward(T x, T n) + +template +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = ((k + k + 1) * x * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; +} // legendre_polynomial_p_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) { + return legendre_polynomial_p_forward(x, static_cast(n)); +} // legendre_polynomial_p_forward(T x, T n) + +template +inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { + static const T A[] = { + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + }; + + static const T B[] = { + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + 
+7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + }; + + T p; + T q = 0.0; + + if (std::abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 30; index++) { + p = q; + q = a; + a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + return std::exp(std::abs(x)) * (T(0.5) * (a - p)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index]; + } + + return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)); +} // modified_bessel_i0_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { + static const T A[] = { + +2.77791411276104639959e-18, + -2.11142121435816608115e-17, + +1.55363195773620046921e-16, + -1.10559694773538630805e-15, + +7.60068429473540693410e-15, + -5.04218550472791168711e-14, + +3.22379336594557470981e-13, + -1.98397439776494371520e-12, + +1.17361862988909016308e-11, + -6.66348972350202774223e-11, + +3.62559028155211703701e-10, + -1.88724975172282928790e-09, + +9.38153738649577178388e-09, + -4.44505912879632808065e-08, + +2.00329475355213526229e-07, + -8.56872026469545474066e-07, + +3.47025130813767847674e-06, + -1.32731636560394358279e-05, + +4.78156510755005422638e-05, + -1.61760815825896745588e-04, + +5.12285956168575772895e-04, + -1.51357245063125314899e-03, + +4.15642294431288815669e-03, + -1.05640848946261981558e-02, + +2.47264490306265168283e-02, + -5.29459812080949914269e-02, + +1.02643658689847095384e-01, + -1.76416518357834055153e-01, + +2.52587186443633654823e-01, + }; + + static const T B[] = { + +7.51729631084210481353e-18, + +4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + +2.96262899764595013876e-16, + +3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + +1.04202769841288027642e-14, + +4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + +2.03562854414708950722e-12, + +1.41258074366137813316e-11, + +3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + +7.78576235018280120474e-01, + }; + + T p; + T q = 0.0; + + if (std::abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 29; index++) { + p = q; + q = a; + a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + if (x < T(0.0)) { + return -(T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x))); + } + + return T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index]; + } + + if (x < T(0.0)) { + return -(std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x))); + } + + return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)); +} // modified_bessel_i1_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { + 
static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return T(0.5) * (a - p) - std::log(0.5 * x) * modified_bessel_i0_forward(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x); +} // modified_bessel_k0_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x; + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p 
+ B[index]; + } + + return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x); +} // modified_bessel_k1_forward(T x) + +template +inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x)) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / std::sqrt(x); +} // T scaled_modified_bessel_k0_forward(T x) + +template +inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; 
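+  // Cephes chbevl-style Clenshaw recurrence: the Chebyshev series with coefficients A is evaluated at the shifted argument x*x - 2, and T(0.5) * (a - p) after the loop is the series value.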
+ } + + return (std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return (T(0.5) * (b - p) / std::sqrt(x)); +} // T scaled_modified_bessel_k1_forward(T x) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + return std::cos(n * std::acos(x + x - T(1.0))); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) { + return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_t_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::sin(std::acos(x + x - T(1.0))) != T(0.0)) { + return std::sin((n + 1) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0))); + } + + return (n + 1) * std::cos((n + 1) * std::acos(x + x - T(1.0))) / (x + x - T(1.0)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) { + return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_u_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::sin(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return std::cos(((n) + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) { + return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_v_forward(T 
x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 4) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::cos(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return std::sin((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) { + return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_w_forward(T x, T n) + +template +inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { + if (std::isinf(x)) { + return T(0.0); + } + + if (std::abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return std::sin(x) / x; +} // T spherical_bessel_j0_forward(T x) + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h b/phivenv/Lib/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h new file mode 100644 index 0000000000000000000000000000000000000000..a1e84f029202bdb27e825a062a63adbcb5151d76 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h @@ -0,0 +1,71 @@ +#pragma once + +namespace at { +// views and their in-place version ops +#define TORCH_VIEW_FNS(m) \ + m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \ + m.impl("detach", torch::CppFunction::makeFallthrough()); \ + m.impl("detach_", torch::CppFunction::makeFallthrough()); \ + m.impl("diagonal", torch::CppFunction::makeFallthrough()); \ + m.impl("expand", torch::CppFunction::makeFallthrough()); \ + m.impl("expand_as", torch::CppFunction::makeFallthrough()); \ + m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \ + m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \ + m.impl("narrow", torch::CppFunction::makeFallthrough()); \ + m.impl("permute", torch::CppFunction::makeFallthrough()); \ + m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("select.int", torch::CppFunction::makeFallthrough()); \ + m.impl("squeeze", torch::CppFunction::makeFallthrough()); \ + m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose_", torch::CppFunction::makeFallthrough()); \ + m.impl("t", torch::CppFunction::makeFallthrough()); \ + m.impl("t_", torch::CppFunction::makeFallthrough()); \ + m.impl("real", torch::CppFunction::makeFallthrough()); \ + m.impl("imag", torch::CppFunction::makeFallthrough()); \ + m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \ + m.impl("unflatten.int", 
torch::CppFunction::makeFallthrough()); \ + m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("unfold", torch::CppFunction::makeFallthrough()); \ + m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \ + m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \ + m.impl("view_as", torch::CppFunction::makeFallthrough()); \ + m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \ + m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \ + m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \ + m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \ + m.impl("swapdims", torch::CppFunction::makeFallthrough()); \ + m.impl("chunk", torch::CppFunction::makeFallthrough()); \ + m.impl("reshape", torch::CppFunction::makeFallthrough()); \ + m.impl("alias", torch::CppFunction::makeFallthrough()); \ + m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("conj", torch::CppFunction::makeFallthrough()); \ + m.impl("_conj", torch::CppFunction::makeFallthrough()); \ + m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \ + m.impl("resize_", torch::CppFunction::makeFallthrough()); + +#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \ + m.impl("empty_like", torch::CppFunction::makeFallthrough()); \ + m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \ + m.impl("empty.out", torch::CppFunction::makeFallthrough()); \ + m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \ + m.impl("full_like", torch::CppFunction::makeFallthrough()); \ + m.impl("stride.int", torch::CppFunction::makeFallthrough()); \ + m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("size.int", torch::CppFunction::makeFallthrough()); \ + m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("is_complex", torch::CppFunction::makeFallthrough()); \ + m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \ + m.impl("requires_grad_", torch::CppFunction::makeFallthrough()); +} + +#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \ + m.impl("as_strided", torch::CppFunction::makeFallthrough()); \ + m.impl("view", torch::CppFunction::makeFallthrough()); diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/MathBitsFallback.h b/phivenv/Lib/site-packages/torch/include/ATen/native/MathBitsFallback.h new file mode 100644 index 0000000000000000000000000000000000000000..af8ca30bcb470514966205211a5da25c3775d309 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/MathBitsFallback.h @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include + +#include +#endif + +namespace at::native { +// This fallback should only be used for operations that are self inverse and have a corresponding tensor +// bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit. +// Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit. 
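+// Illustration (added sketch, not part of the upstream header): "self inverse" here
+// means that applying the operation twice is the identity, e.g. conj(conj(z)) == z
+// and neg(neg(x)) == x, which is what allows a set bit to be resolved by
+// materializing the tensor once and then clearing the bit.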
+// Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called. + +// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit. +struct MathOpFallback { + MathOpFallback(DispatchKey key_, std::string op_name_) : key(key_), op_name(std::move(op_name_)) {} + virtual bool is_bit_set(const Tensor&) = 0; + void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { + /* + Situations to handle: + 1. Out-of-place operation. Easy: materialize all inputs and + call it a day. + 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_(). + Materialize other inputs as in (1). + 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2)) + Materialize other inputs as in (1). + + It is important to be able to tell if we READ from an argument and if we + WRITE to an argument. Conservative approach is to assume that we always + READ from an argument, but in out= operations you can skip + conjugating inputs on entry that never get used. In the current schema we + can't easily tell if the operation is in in-place or out= operation. + + Note: + 1. Mutable tensorlists containing tensors whose math bit set to true are disallowed. + 2. Mutable tensors with math bit set to true are unconditionally cloned to ensure + correct behavior in the case when the mutable tensor shares memory with non mutable arguments. + + If we were to in-place resolve the math bit for mutable inputs, then the non-mutable inputs sharing partial or full memory + with these mutable inputs would read into wrong values in the following cases: + 1. Non mutable inputs have their math bit set to false. + 2. Math bit for mutable input(s) is resolved before the non mutable inputs (with bit set to true and sharing memory + with one or more mutable arg(s)) are cloned. + At the end, the final value of the mutable arguments from the stack are copied into the original input mutable tensor inputs. + */ + const auto& arguments = op.schema().arguments(); + const auto num_arguments = arguments.size(); + const auto stack_start = stack->size() - num_arguments; + + std::optional is_write; + for (const auto i : c10::irange(num_arguments)) { + // Three possible states: + // 1. alias_info has no value --> out-of-place operation + // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation + // 3. alias_info does have a value, alias_info->is_write=False --> view operation + const AliasInfo* alias_info = arguments[i].alias_info(); + if (alias_info != nullptr) { + if (is_write.has_value()) { + TORCH_CHECK(*is_write == alias_info->isWrite(), + "Unsupported operator for ", op_name, " fallback: ", op.schema().name(), + op_name, " fallback doesn't work for operators with a mix " + "mutable and non-mutable inputs that alias with outputs, " + "this must be implemented manually. " + "If you got this error on a core op, please report a bug to PyTorch."); + } else { + is_write = alias_info->isWrite(); + } + } + } + + if (is_write.has_value() && !*is_write) { + // We assume that view operators automatically handle the math bit + // correctly by propagating the dispatch key in key_set. + // This is not necessarily always right, so you should test these cases. 
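+      // Illustration (added sketch, not part of the upstream header): the view ops
+      // that reach this branch are alias-returning ops that were not registered as
+      // fallthroughs (cf. the TORCH_VIEW_FNS list earlier in this diff); they are
+      // redispatched past the math-bit key on the assumption that the returned view
+      // aliases the input and keeps the bit in its key_set, so the bit is still
+      // resolved later when a non-view op consumes it.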
+ op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack); + return; + } + + // Mutable inputs with math bit set to True and their clones + std::vector> mutable_inputs_with_their_clones; + for (const auto i : c10::irange(num_arguments)) { + auto& ivalue = (*stack)[stack_start + i]; + if (!(ivalue.isTensor() || ivalue.isTensorList())) { + continue; + } + const auto& argument = arguments[i]; + bool mut_arg = false; + if (argument.alias_info()) { + // Was already tested by is_write loop above + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite()); + mut_arg = true; + } + if (ivalue.isTensor()) { + if (!is_bit_set(ivalue.toTensor())) { + continue; + } + auto tensor = std::move(ivalue).toTensor(); + auto resolved_tensor = at::clone(tensor); + if (mut_arg) { + TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ", + op_name, "bit set to true."); + mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor); + } + (*stack)[stack_start + i] = std::move(resolved_tensor); + } else if (ivalue.isTensorList()) { + auto tensors = std::move(ivalue).toTensorList(); + for(const auto j : c10::irange(tensors.size())) { + const auto& tensor = tensors[j]; + if (!is_bit_set(tensor)) { + continue; + } + TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ", + op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ", + op.schema().name()); + tensors[j] = at::clone(tensor); + } + (*stack)[stack_start + i] = std::move(tensors); + } + } + + op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack); + + TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1); + + for (std::pair mut_tensors: mutable_inputs_with_their_clones) { + auto& mutable_input = mut_tensors.first; + auto& cloned_mutable_input = mut_tensors.second; + auto& ivalue = (*stack)[stack_start]; + auto returned_output = std::move(ivalue).toTensor(); + + // sanity check to ensure that the tensor in stack aliases the cloned_mutable_input + TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output)); + + // necessary for out= arg + at::native::resize_output(mutable_input, returned_output.sizes()); + + mutable_input.copy_(returned_output); + (*stack)[stack_start] = std::move(mutable_input); + } + } + + virtual ~MathOpFallback() = default; + + DispatchKey key; + std::string op_name; +}; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/MaxPooling.h b/phivenv/Lib/site-packages/torch/include/ATen/native/MaxPooling.h new file mode 100644 index 0000000000000000000000000000000000000000..1426c513c83fb4b5c929a0fc3345d1f9ebef9039 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/MaxPooling.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +inline void check_max_pool1d( + const Tensor& self, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { + + TORCH_CHECK( + self.dim() == 2 || self.dim() == 3, + "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes()); + TORCH_CHECK( + kernel_size.size() == 1, + "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ", + kernel_size.size()); + TORCH_CHECK( + stride.empty() || 
stride.size() == 1, + "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ", + stride.size()); + TORCH_CHECK( + padding.size() == 1, + "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ", + padding.size()); + TORCH_CHECK( + dilation.size() == 1, + "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ", + dilation.size()); + + // If stride=None then set it to kernel_size + if (stride.empty()) { + stride = kernel_size; + } + + TORCH_CHECK( + kernel_size[0] > 0, + "max_pool1d() kernel_size must be greater than zero, but got ", + kernel_size[0]); + TORCH_CHECK( + stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]); + TORCH_CHECK( + padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]); + TORCH_CHECK( + padding[0] <= kernel_size[0] / 2, + "max_pool1d() padding should be at most half of kernel size, but got padding=", + padding[0], + " and kernel_size=", + kernel_size[0]); + TORCH_CHECK( + dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]); + + const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode); + TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW); +} + +// TODO(Heitor) Template by dimension +struct PoolingParams1D { + int64_t NB; // Number of batches + int64_t NC; // Number of channels + int64_t IW; // Input width + int64_t OW; // Output width + int64_t KW; // Kernel width + int64_t SJ; // Column stride + int64_t PJ; // Column padding + int64_t DJ; // Column dilation + + // Return index of input element for the given kernel and output index + inline int64_t index(int64_t kj, int64_t oj) const { + return oj * SJ + kj * DJ - PJ; + } + + // Return index of first output within bounds for this kernel index + inline int64_t valid_output_start(int64_t kj) const { + int64_t ij = index(kj, 0);; + return ij < 0 ? at::divup(-ij, SJ) : 0; + } + + // Return index one past last output within bounds for this kernel index + inline int64_t valid_output_end(int64_t kj) const { + int64_t ij = index(kj, OW - 1); + return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW; + } +}; + +using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&); + +DECLARE_DISPATCH(pooling_fn, max_pool1d_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/NonEmptyUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/NonEmptyUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..26cb65d844b4f0e1d88a45712159d18c0312ab73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/NonEmptyUtils.h @@ -0,0 +1,27 @@ +#include +#include +#include + +namespace at::native { + +inline int64_t ensure_nonempty_dim(int64_t dim) { + return std::max(dim, 1); +} + +inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) { + return t.dim() == 0 ? 1 : t.size(dim); +} + +inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) { + return t.dim() == 0 ? 
1 : t.stride(dim); +} + +using IdxVec = std::vector; +inline IdxVec ensure_nonempty_vec(IdxVec vec) { + if (vec.empty()) { + vec.push_back(1); + } + return vec; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/NonSymbolicBC.h b/phivenv/Lib/site-packages/torch/include/ATen/native/NonSymbolicBC.h new file mode 100644 index 0000000000000000000000000000000000000000..83827425118ba3ab041904d5066c7f4bf44c077f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/NonSymbolicBC.h @@ -0,0 +1,26 @@ +#pragma once +#include +#include +#include + +namespace at::native { +// This file contains non-symbolic signatures for ops that we have sym-intified the signature of. +// However, in certain cases (such as static runtime), we call the native versions of the ops directly. +// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation. +TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape); +TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length); +TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, std::optional dtype=std::nullopt, std::optional layout=std::nullopt, std::optional device=std::nullopt, std::optional pin_memory=std::nullopt, std::optional is_coalesced=std::nullopt); +TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index); +TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index); +// The below ops don't get a duplicated C++ implementation. +// They are backward ops, which make them very unlikely to be called directly +// by external code (at::native::trace_backward). +// They get their own declaration for BC purposes however. 
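+// Illustrative assumption (added sketch, not taken from this diff): the sym-intified
+// counterpart of a declaration such as reshape() above would take symbolic sizes,
+// e.g. Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape);
+// static runtime sidesteps it by calling the IntArrayRef overload declared here.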
+TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const std::optional & per_sample_weights, int64_t padding_idx=-1); +TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const std::optional & per_sample_weights, int64_t padding_idx=-1); +TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim); +TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes); +TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index); +TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index); +TORCH_API std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim); +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Normalization.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..644beb731be36beb44373fb7c466bb023dd73cb6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Normalization.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace at::native { + +using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm); +DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub) + +enum class BatchNormBackend { + Native, + Cudnn, + Miopen, +}; + +TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Padding.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Padding.h new file mode 100644 index 0000000000000000000000000000000000000000..00c6c6a5b65858ea5a1fcb877f097e9e5ebddaca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Padding.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +namespace at::native { + +using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef); + +// reflection padding +DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel) +DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel) +DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel) +DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel) +DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel) +DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel) + +// replication padding +DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel) +DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel) +DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel) +DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel) +DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel) +DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel) + +namespace padding { + +template +inline void check_valid_input(const Tensor& input, IntArrayRef padding) { + + TORCH_CHECK(padding.size() == 2 * 
dim, + "padding size is expected to be ", 2 * dim, + ", but got: ", padding.size()); + + int input_dim = input.dim(); + + bool is_batch_mode = input_dim == (dim + 2); + bool is_non_batch_mode = input_dim == (dim + 1); + + bool valid_batch_mode = is_batch_mode; + bool valid_non_batch_mode = is_non_batch_mode; + + if (is_batch_mode) { + // allow batch size of 0-dim. + for (const auto d : c10::irange(1, input_dim)) { + valid_batch_mode = valid_batch_mode && input.size(d) != 0; + } + } else { + for (const auto d : c10::irange(0, input_dim)) { + valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0; + } + } + + // allow empty batch size but not other dimensions. + TORCH_CHECK(valid_batch_mode || valid_non_batch_mode, + "Expected ", dim + 1, "D or ", dim + 2, + "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); +} + +} // namespace padding + +} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/PixelShuffle.h b/phivenv/Lib/site-packages/torch/include/ATen/native/PixelShuffle.h new file mode 100644 index 0000000000000000000000000000000000000000..0422760865548a345637e1fd8ffd3d7e0c76e72e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/PixelShuffle.h @@ -0,0 +1,46 @@ +#include +#include + +namespace at::native { + +inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) { + TORCH_CHECK(self.dim() >= 3, + "pixel_shuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), " dimension(s)"); + TORCH_CHECK(upscale_factor > 0, + "pixel_shuffle expects a positive upscale_factor, but got ", + upscale_factor); + int64_t c = self.size(-3); + int64_t upscale_factor_squared = upscale_factor * upscale_factor; + TORCH_CHECK(c % upscale_factor_squared == 0, + "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " + "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared); +} + +inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) { + TORCH_CHECK( + self.dim() >= 3, + "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), + " dimension(s)"); + TORCH_CHECK( + downscale_factor > 0, + "pixel_unshuffle expects a positive downscale_factor, but got ", + downscale_factor); + int64_t h = self.size(-2); + int64_t w = self.size(-1); + TORCH_CHECK( + h % downscale_factor == 0, + "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", + h, + " is not divisible by ", + downscale_factor); + TORCH_CHECK( + w % downscale_factor == 0, + "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", + w, + " is not divisible by ", + downscale_factor); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/PointwiseOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/PointwiseOps.h new file mode 100644 index 0000000000000000000000000000000000000000..fec01a370318269ca33dadada7b73a72f3743ba5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/PointwiseOps.h @@ -0,0 +1,28 @@ +// Ternary and higher-order pointwise operations +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { + +struct TensorIterator; +struct TensorIteratorBase; + +namespace native { + +using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar); +using 
structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar); +using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double); + +DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub) +DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub) +DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub) +DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub) +DECLARE_DISPATCH(pointwise_fn, mse_backward_stub) + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Pool.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Pool.h new file mode 100644 index 0000000000000000000000000000000000000000..63421d1a36ec194594145d3a0c74ed8d66e10b23 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Pool.h @@ -0,0 +1,361 @@ +#include +#include +#include +#include +#include + +#include + +#pragma once + +namespace at::native { + +using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, + int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH); +using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); + +DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel) +DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel) + +// averge pooling has same signature for forward and backward +using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, + int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional divisor_override); +using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH, + int dW, int dH, int padW, int padH, bool count_include_pad, std::optional divisor_override); + +DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel) +DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel) + +// averge pooling has same signature for forward and backward +using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input, + int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD, + int64_t padW, int64_t padH, int64_t padD, bool count_include_pad, + std::optional divisor_override); +using avg_pool3d_backward_fn = void(*)(const Tensor& output, const Tensor& input, + int kW, int kH, int kD, int dW, int dH, int dD, + int padW, int padH, int padD, bool count_include_pad, + std::optional divisor_override); + +DECLARE_DISPATCH(avg_pool3d_fn, avg_pool3d_kernel) +DECLARE_DISPATCH(avg_pool3d_backward_fn, avg_pool3d_backward_kernel) + +using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input, + int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD); +using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); + +DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel) +DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel) +namespace { + +template +inline dest_t +safe_downcast(src_t v) +{ + TORCH_CHECK(std::numeric_limits::min() <= v && v <= std::numeric_limits::max(), + "integer out of range"); + + return static_cast(v); +} + +template +inline T pooling_output_shape_pad_lr( + T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation, + bool ceil_mode) { + T outputSize = div_rtn( + inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + + (ceil_mode ? 
stride - 1 : 0), stride) + 1; + if (ceil_mode) { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((outputSize - 1) * stride >= inputSize + pad_l) { + --outputSize; + } + } + return outputSize; +} + +template +inline T pooling_output_shape( + T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) { + TORCH_CHECK(stride != 0, "stride should not be zero"); + TORCH_CHECK(pad >= 0, + "pad must be non-negative, but got pad: ", pad); + TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2, + "pad should be at most half of effective kernel size, but got pad=", + pad, ", kernel_size=", kernelSize, " and dilation=", dilation) + return pooling_output_shape_pad_lr( + inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode); +} + +template +std::pair _pooling_same_mode_padding_lr( + T inputSize, T kernelSize, T stride, T dilation) { + // NOTE: with strides, the output shape is ceil(inputSize/stride) + auto total_padding = T(dilation) * (kernelSize - 1); + + // Prefer symmetric padding if possible + if (stride > 2 && (total_padding % 2 == 1)) { + // The floor in the output size calculation gives us a little wiggle room + auto wiggle_room = inputSize % stride - 1; + if (wiggle_room > 0) { + total_padding = total_padding - 1; + } + } + + auto left = total_padding / 2; + return {left, total_padding - left}; +} + +inline std::pair pooling_same_mode_padding_lr( + int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) { + return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation); +} + +inline std::pair pooling_same_mode_padding_lr( + c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) { + return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation)); +} + +// AveragePool2d/DilatedMaxPool2d (forward) +inline void +pool2d_shape_check( + const Tensor& input, + int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, int64_t dilationH, int64_t dilationW, + int64_t nInputPlane, + int64_t inputHeight, int64_t inputWidth, + int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format) +{ + const int64_t ndim = input.ndimension(); +#ifndef STRIP_ERROR_MESSAGES + const int64_t nOutputPlane = nInputPlane; +#endif + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got ", + "kH: ", kH, " kW: ", kW); + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got " + "dH: ", dH, " dW: ", dW); + TORCH_CHECK(dilationH > 0 && dilationW > 0, + "dilation should be greater than zero, but got ", + "dilationH: ", dilationH, " dilationW: ", dilationW); + + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + if (memory_format == at::MemoryFormat::ChannelsLast){ + // Expect tensor in NHWC format and allow 0-dim only for N. 
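+  // Added sketch (worked example, not part of the upstream header): the
+  // outputHeight/outputWidth validated below come from pooling_output_shape above;
+  // e.g. inputSize=5, kernelSize=2, pad=0, stride=2, dilation=1 gives
+  // div_rtn(5 - 1 - 1, 2) + 1 = 2 windows in floor mode, while ceil_mode adds
+  // (stride - 1) to the numerator and gives 3, which the "last window must start
+  // inside the image" correction keeps because (3 - 1) * 2 < 5.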
+ TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 4D (batch mode) tensor expected for input with channels_last layout" + " with optional 0 dim batch size for input, but got: ", input.sizes()); + } else { + TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:", + input.sizes()); + } + + TORCH_CHECK(kW/2 >= padW && kH/2 >= padH, + "pad should be smaller than or equal to half of kernel size, but got ", + "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH); + + TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, + "Given input size: (", + nInputPlane, "x", inputHeight, "x", inputWidth, "). ", + "Calculated output size: (", + nOutputPlane, "x", outputHeight, "x", outputWidth, "). ", + "Output size is too small"); +} + +// DilatedMaxPool2d (backward) +inline void +max_pool2d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + const Tensor& indices, + int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW, + int64_t nInputPlane, + int64_t inputHeight, int64_t inputWidth, + int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format) +{ + pool2d_shape_check( + input, + kH, kW, dH, dW, padH, padW, dilationH, dilationW, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format); + + const int64_t ndim = input.ndimension(); + const int64_t nOutputPlane = nInputPlane; + + check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane); + check_dim_size(gradOutput, ndim, ndim-2, outputHeight); + check_dim_size(gradOutput, ndim, ndim-1, outputWidth); + + check_dim_size(indices, ndim, ndim-3, nOutputPlane); + check_dim_size(indices, ndim, ndim-2, outputHeight); + check_dim_size(indices, ndim, ndim-1, outputWidth); + + if (ndim == 4) { + const int64_t batchSize = input.size(0); + check_dim_size(gradOutput, ndim, 0, batchSize); + check_dim_size(indices, ndim, 0, batchSize); + } +} + +// AveragePool2d (backward) +inline void +avg_pool2d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + int64_t /*nbatch*/, + int kH, int kW, int dH, int dW, int padH, int padW, + int64_t nInputPlane, + int64_t inputHeight, int64_t inputWidth, + int64_t outputHeight, int64_t outputWidth, + MemoryFormat memory_format) +{ + pool2d_shape_check( + input, + kH, kW, dH, dW, padH, padW, 1, 1, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + memory_format); + + const int64_t ndim = input.ndimension(); + const int64_t nOutputPlane = nInputPlane; + + check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane); + check_dim_size(gradOutput, ndim, ndim-2, outputHeight); + check_dim_size(gradOutput, ndim, ndim-1, outputWidth); +} + +// AveragePool3d/DilatedMaxPool3d (forward) +inline void +pool3d_shape_check( + const Tensor& input, + int64_t nslices, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int dilationT, int dilationH, int dilationW, + int64_t itime, int64_t iheight, int64_t iwidth, + int64_t otime, int64_t oheight, int64_t owidth, + const char *fn_name, + bool check_input_size=false) +{ + const int64_t ndim = input.ndimension(); + + TORCH_CHECK(kT > 0 && kW > 0 && kH > 0, + "kernel size should be greater than zero, but got ", + "kT: ", kT, " kH: ", kH, " kW: ", kW); + TORCH_CHECK(dT > 0 && dW > 0 && dH > 0, + "stride should be greater than zero, but got ", + "dT: ", dT, " dH: ", dH, " dW: ", dW); + 
TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0, + "dilation should be greater than zero, but got ", + "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW); + + TORCH_CHECK(ndim == 4 || ndim == 5, + fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes()); + + for (const auto i : c10::irange(ndim)) { + if (ndim == 5 && i == 0) { + // size of batch-dim can be 0. + continue; + } + TORCH_CHECK( + input.size(i) > 0, + fn_name, + ": Expected input's non-batch dimensions to have positive length," + " but input has a shape of ", + input.sizes(), + " and non-batch dimension ", + input.size(i), + " has length zero!") + } + + if (check_input_size) { // AveragePool3d + TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW, + "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ", + "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")"); + } + + TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH, + "pad should be smaller than or equal to half of kernel size, but got " + "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH); + + TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1, + "Given input size: (", + nslices,"x", itime, "x", iheight, "x", iwidth, "). ", + "Calculated output size: (", + nslices, "x", otime, "x", oheight, "x", owidth, "). ", + "Output size is too small"); +} + +inline void +max_pool3d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + const Tensor& indices, + int64_t nslices, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int dilationT, int dilationH, int dilationW, + int64_t itime, int64_t iheight, int64_t iwidth, + int64_t otime, int64_t oheight, int64_t owidth, + const char* fn_name) +{ + const int64_t ndim = input.ndimension(); + + pool3d_shape_check( + input, + nslices, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + dilationT, dilationH, dilationW, + itime, iheight, iwidth, + otime, oheight, owidth, fn_name); + + check_dim_size(gradOutput, ndim, ndim-4, nslices); + check_dim_size(gradOutput, ndim, ndim-3, otime); + check_dim_size(gradOutput, ndim, ndim-2, oheight); + check_dim_size(gradOutput, ndim, ndim-1, owidth); + + check_dim_size(indices, ndim, ndim-4, nslices); + check_dim_size(indices, ndim, ndim-3, otime); + check_dim_size(indices, ndim, ndim-2, oheight); + check_dim_size(indices, ndim, ndim-1, owidth); +} + +inline void +avg_pool3d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + int64_t nslices, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int64_t itime, int64_t iheight, int64_t iwidth, + int64_t otime, int64_t oheight, int64_t owidth, + const char *fn_name) +{ + const int64_t ndim = input.ndimension(); + + pool3d_shape_check( + input, + nslices, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + 1, 1, 1, + itime, iheight, iwidth, + otime, oheight, owidth, + fn_name, true); + + check_dim_size(gradOutput, ndim, ndim-4, nslices); + check_dim_size(gradOutput, ndim, ndim-3, otime); + check_dim_size(gradOutput, ndim, ndim-2, oheight); + check_dim_size(gradOutput, ndim, ndim-1, owidth); +} + +} // anonymous namespace + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Pow.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Pow.h new file mode 100644 index 0000000000000000000000000000000000000000..64fd7a2f9702252db8cd5a0a0badb435d59f8329 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/Pow.h @@ -0,0 +1,69 @@ +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { + +struct TensorIterator; +struct TensorIteratorBase; + +namespace native { + +#if defined(__CUDACC__) || defined(__HIPCC__) +#define HOST_DEVICE __host__ __device__ +#else +#define HOST_DEVICE +#endif + +// integral power in pytorch allows for negative exponents, giving truncated integral results. +// e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the +// only non-zero result. +template , T>* = nullptr> +inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { + T result = 1; + while (b) { + if (b & 1) { + result *= a; + } + b /= 2; + a *= a; + } + return result; +} + +template && !std::is_signed_v, T>* = nullptr> +inline HOST_DEVICE T powi(T a, T b) { + return powi_impl(a, b); +} + +template && std::is_signed_v, T>* = nullptr> +inline HOST_DEVICE T powi(T a, T b) { + if ( b < 0 ) { + if ( a == 1 ) { + return 1; + } else if ( a == -1 ) { + auto negative = (-b) % static_cast(2); + return negative ? -1 : 1; + } else { + return 0; + } + } + return powi_impl(a, b); +} + +using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&); +using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); + +DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub) +DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub) + +} // namespace native + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/RNN.h b/phivenv/Lib/site-packages/torch/include/ATen/native/RNN.h new file mode 100644 index 0000000000000000000000000000000000000000..e09132554289c4b6e33e3887f8d0a32aa2bf4c3c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/RNN.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +namespace at::native { + +using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool); +using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool); +using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool); +using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool); + +DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub) +DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub) +DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub) +DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub) +DECLARE_DISPATCH(rnn_fn, gru_miopen_stub) +DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub) +DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub) +DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub) +DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub) +DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub) +DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub) +DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub) +DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub) +DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub) +DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub) +DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub) +DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub) + +inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) { + auto input_device = input.device(); + auto input_dtype = 
input.scalar_type(); + + auto check_tensors = [&](const std::string& name, const Tensor& t) { + if (!t.defined()) return; + auto t_device = t.device(); + TORCH_CHECK(input_device == t_device, + "Input and ", name, " tensors are not at the same device, found input tensor at ", + input_device, " and ", name, " tensor at ", t_device); + if (check_dtype) { + auto t_dtype = t.scalar_type(); + TORCH_CHECK(input_dtype == t_dtype, + "Input and ", name, " tensors are not the same dtype, found input tensor with ", + input_dtype, " and ", name, " tensor with ", t_dtype); + } + }; + + for (const auto& h : hiddens) check_tensors("hidden", h); + for (const auto& p : params) check_tensors("parameter", p); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/RangeFactories.h b/phivenv/Lib/site-packages/torch/include/ATen/native/RangeFactories.h new file mode 100644 index 0000000000000000000000000000000000000000..8b6c1c918dcf5ccbd73d0162b0f0e6836d39ca78 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/RangeFactories.h @@ -0,0 +1,12 @@ +#include +#include + +namespace at { +struct TensorIterator; + +namespace native { + +DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub) +DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub) + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/RangeUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/RangeUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..8c4414543dd4c965c3519b4ad9fa854b79b5c1c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/RangeUtils.h @@ -0,0 +1,60 @@ +#include +#include +#include + + + +namespace at::native { + +inline void arange_check_bounds( + const c10::Scalar& start, + const c10::Scalar& end, + const c10::Scalar& step) { + // use double precision for validation to avoid precision issues + double dstart = start.to(); + double dend = end.to(); + double dstep = step.to(); + + TORCH_CHECK(dstep > 0 || dstep < 0, "step must be nonzero"); + TORCH_CHECK( + std::isfinite(dstart) && std::isfinite(dend), + "unsupported range: ", + dstart, + " -> ", + dend); + TORCH_CHECK( + ((dstep > 0) && (dend >= dstart)) || ((dstep < 0) && (dend <= dstart)), + "upper bound and lower bound inconsistent with step sign"); +} + +template +int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) { + arange_check_bounds(start, end, step); + + // we use double precision for (start - end) / step + // to compute size_d for consistency across devices. + // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, + // but double on cpu for the same, + // and the effective output size starts differing on CPU vs GPU because of precision issues, which + // we dont want. 
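+  // Added sketch (worked example, not part of the upstream header): for
+  // arange(start=0, end=5, step=2) this computes ceil((5 - 0) / 2) = 3 elements,
+  // i.e. {0, 2, 4}; doing the same division in float32 on one device and in double
+  // on another could round a boundary case to different sizes, which is exactly the
+  // inconsistency described above.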
+ // the corner-case we do want to take into account is int64_t, which has higher precision than double + double size_d; + if constexpr (std::is_same_v) { + using accscalar_t = at::acc_type; + auto xstart = start.to(); + auto xend = end.to(); + auto xstep = step.to(); + int64_t sgn = (xstep > 0) - (xstep < 0); + size_d = std::ceil((xend - xstart + xstep - sgn) / xstep); + } else { + size_d = std::ceil(static_cast(end.to() - start.to()) + / step.to()); + } + + TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), + "invalid size, possible overflow?"); + + return static_cast(size_d); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceAllOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceAllOps.h new file mode 100644 index 0000000000000000000000000000000000000000..c9e75d796c43f59bad222f0d867f0828e595cd34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceAllOps.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace at { +class Tensor; +} + +namespace at::native { + +using reduce_all_fn = void (*)(Tensor & result, const Tensor & self); +using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self); +DECLARE_DISPATCH(reduce_all_fn, min_all_stub) +DECLARE_DISPATCH(reduce_all_fn, max_all_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceOps.h new file mode 100644 index 0000000000000000000000000000000000000000..d3528334ae86c8ceb7b3775fccd9969a485288ec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceOps.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +class Tensor; +} + +namespace at::native { + +using reduce_fn = void(*)(TensorIterator &); + +DECLARE_DISPATCH(reduce_fn, sum_stub) +DECLARE_DISPATCH(reduce_fn, nansum_stub) +DECLARE_DISPATCH(reduce_fn, prod_stub) +DECLARE_DISPATCH(reduce_fn, mean_stub) +DECLARE_DISPATCH(reduce_fn, and_stub) +DECLARE_DISPATCH(reduce_fn, or_stub) +DECLARE_DISPATCH(reduce_fn, min_values_stub) +DECLARE_DISPATCH(reduce_fn, max_values_stub) +DECLARE_DISPATCH(reduce_fn, argmax_stub) +DECLARE_DISPATCH(reduce_fn, argmin_stub) + +using reduce_std_var_function = + void (*)(TensorIterator&, double correction, bool take_sqrt); +DECLARE_DISPATCH(reduce_std_var_function, std_var_stub) + +using reduce_norm_fn = + void (*)(Tensor&, const Tensor&, const c10::Scalar&, std::optional); +DECLARE_DISPATCH(reduce_norm_fn, norm_kernel) + +using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&); +DECLARE_DISPATCH(reduce_fn_flag, norm_stub) + +using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t); +using cum_fn = void (*)(Tensor&, const Tensor&, int64_t); +DECLARE_DISPATCH(structured_cum_fn, cumsum_stub) +DECLARE_DISPATCH(structured_cum_fn, cumprod_stub) +DECLARE_DISPATCH(cum_fn, logcumsumexp_stub) + +DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub) +DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub) + +// Used in cuda/Normalization.cu +TORCH_API std::tuple var_mean_out( + Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, + int64_t correction, bool keepdim); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceOpsUtils.h 
b/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..816dc12128d0114ba4773d23b9edcf10eddf3593 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ReduceOpsUtils.h @@ -0,0 +1,468 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::native { + +// Maximum and minimum possible scalar values, including infinities +template +constexpr scalar_t upper_bound() { + using lim = std::numeric_limits; + return lim::has_infinity ? lim::infinity() : lim::max(); +} + +template +constexpr scalar_t lower_bound() { + using lim = std::numeric_limits; + return lim::has_infinity ? -lim::infinity() : lim::lowest(); +} + +inline Tensor restride_dim( + const Tensor& src, int64_t dim, + IntArrayRef replacement_shape +) { + auto strides = ensure_nonempty_vec(src.strides().vec()); + strides[dim] = 0; + return src.as_strided(replacement_shape, strides); +} + +inline void _dimreduce_setup(const Tensor &result, const Tensor &self, + int64_t dim) { + IntArrayRef self_sizes = self.sizes(); + std::vector result_sizes; + result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end()); + result_sizes[dim] = 1; + result.resize_(result_sizes); +} + +inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self, + const Scalar& ident, int64_t dim, bool keepdim) { + if (self.numel() == 1 && self.ndimension() == 0) { + result.resize_({}); + result.fill_(self); + return true; + } + // Return identity + if (self.numel() == 0) { + _dimreduce_setup(result, self, dim); + result.fill_(ident); + if (!keepdim) result.squeeze_(dim); + return true; + } + return false; +} + +inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self, + int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) { + if (self.numel() == 1 && self.ndimension() == 0) { + result.resize_({}); + result.fill_(self); + return true; + } + + return false; +} + +inline std::optional _allreduce_return_trivial( + const Tensor& self, + const Scalar& ident) { + // Return identity + if (self.numel() == 0) { + return at::scalar_tensor(ident, self.options()); + } + return std::nullopt; +} + +#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \ +{ \ + TORCH_CHECK(\ + out.option() == self.option(),\ + "expected ", #option, " ",\ + self.option(),\ + " but found ", out.option())\ +} + +inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) { + OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self); + OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options()); + OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options()); +} + +inline Tensor integer_upcast(const Tensor& self, std::optional dtype) { + ScalarType scalarType = self.scalar_type(); + TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented"); + ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? 
ScalarType::Long : scalarType); + return self.toType(upcast_scalarType); +} + +using DimMask = TensorIterator::DimMask; + +inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) { + if (opt_dims.has_value()) { + return DimVector(opt_dims.value()); + } else { + std::vector all_dims(ndim); + std::iota(all_dims.begin(), all_dims.end(), 0); + return DimVector(all_dims); + } +} + +inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) { + DimMask mask; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + if (dims.empty() && !allow_empty_dims) { + mask = DimMask().flip(); + } else { + mask = at::dim_list_to_bitset(dims, ndim); + } + } else { + mask = DimMask().flip(); + } + return mask; +} + +inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) { + auto shape = DimVector(self.sizes()); + for (int dim = shape.size() - 1; dim >= 0; dim--) { + if (mask[dim]) { + if (keepdim) { + shape[dim] = 1; + } else { + shape.erase(shape.begin() + dim); + } + } + } + return shape; +} + +inline void resize_reduction_result( + Tensor& result, const Tensor& self, DimMask mask, bool keepdim, + ScalarType /*dtype*/) +{ + auto shape = shape_from_dim_mask(self, mask, keepdim); + TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor."); + at::native::resize_output(result, shape); +} + +inline Tensor create_reduction_result( + const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype +) { + DimMask mask = make_dim_mask(dim, self.dim()); + auto shape = shape_from_dim_mask(self, mask, keepdim); + return at::empty(shape, self.options().dtype(dtype)); +} + +inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) { + if (keepdim) { + return result; + } + auto shape = DimVector(result.sizes()); + auto stride = DimVector(result.strides()); + for (const auto dim : c10::irange(ndim)) { + if (mask[dim]) { + shape.insert(shape.begin() + dim, 1); + stride.insert(stride.begin() + dim, 0); + } + } + return result.as_strided(shape, stride); +} + +inline TensorIterator make_reduction( + const char* name, Tensor& result, const Tensor& self, + at::OptionalIntArrayRef dim_opt, + bool keepdim, ScalarType in_dtype, ScalarType out_dtype) { + // check that result type and dtype match if provided + TORCH_CHECK( + !result.defined() || result.scalar_type() == out_dtype, + name, ": provided dtype must match dtype of result. Got ", + toString(result.scalar_type()), + " and ", + toString(out_dtype), + "."); + // dim={} performs an all-reduce, same as dim=None + IntArrayRef dim = dim_opt.value_or(IntArrayRef{}); + int64_t ndim = self.dim(); + auto mask = make_dim_mask(dim, ndim); + resize_reduction_result(result, self, mask, keepdim, out_dtype); + auto viewed_result = review_reduce_result(result, ndim, mask, keepdim); + namedinference::propagate_names_for_reduction(result, self, dim, keepdim); + if (self.scalar_type() == in_dtype) { + return TensorIterator::reduce_op(viewed_result, self); + } + return TensorIterator::reduce_op(viewed_result, self.to(in_dtype)); +} + +[[maybe_unused]] inline TensorIterator make_reduction( + const char* name, + Tensor& result, + const Tensor& self, + at::OptionalIntArrayRef dim, + bool keepdim, + ScalarType out_dtype) { + // special case for type promotion in mixed precision, improves computational + // efficiency. 
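+  // Added sketch (illustration, not part of the upstream header): e.g. reducing a
+  // CUDA kHalf tensor into a kFloat result keeps in_dtype == kHalf here, so the
+  // kernel loads half values and accumulates in float instead of first
+  // materializing a float copy of the whole input.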
+ // not generalize this to common mismatched input/output types to avoid cross + // product of templated kernel launches. + const bool gpu_lowp_to_f32 = ( + (self.is_cuda() || self.is_xpu()) && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat); + auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() + : self.is_complex() ? c10::toComplexType(out_dtype) + : out_dtype; + return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype); +} + +inline TensorIterator make_reduction( + const char* name, Tensor& result1, Tensor& result2, const Tensor& self, + at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1, + ScalarType dtype2) { + // check that result type and dtype match if provided + TORCH_CHECK( + (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2), + name, ": provided dtype must match dtype of result. Got ", + toString(result1.scalar_type()), toString(result2.scalar_type()), + " and ", + toString(dtype1), toString(dtype2), + "."); + + // dim={} performs an all-reduce, same as dim=None + auto dim = dim_opt.value_or(IntArrayRef{}); + int64_t ndim = self.dim(); + DimMask mask = make_dim_mask(dim, ndim); + resize_reduction_result(result1, self, mask, keepdim, dtype1); + auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim); + + resize_reduction_result(result2, self, mask, keepdim, dtype2); + auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim); + + namedinference::propagate_names_for_reduction(result1, self, dim, keepdim); + namedinference::propagate_names_for_reduction(result2, self, dim, keepdim); + + // special case for type promotion in mixed precision, improves computational + // efficiency. + // We don't generalize this to common mismatched input/output types to avoid cross + // product of templated kernel launches. + if (self.scalar_type() == dtype1 || + (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) { + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self); + } + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1)); +} + +[[maybe_unused]] inline TensorIterator make_reduction( + const char* name, + Tensor& result1, + Tensor& result2, + const Tensor& self, + at::OptionalIntArrayRef dim, + bool keepdim, + ScalarType dtype) { + return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype); +} + +inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) { + if (self.ndimension() == 0) { + TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name, + ": Expected reduction dim -1 or 0 for scalar but got ", dim); + } + else { + TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name, + ": Expected reduction dim ", dim, " to have non-zero size."); + } +} + +inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) { + TORCH_CHECK( + !dim.empty(), + fn_name, ": Expected reduction dim to be specified for input.numel() == 0. 
", + "Specify the reduction dim with the 'dim' argument."); + for (const int64_t d : dim) { + zero_numel_check_dims(self, d, fn_name); + } +} + +inline std::vector get_zero_numel_tensor_size( + const Tensor& self, + const int64_t dim, + const bool keepdim, + const char* fn_name) { + TORCH_INTERNAL_ASSERT(self.numel() == 0, fn_name, ": Expected self.numel() == 0."); + zero_numel_check_dims(self, dim, fn_name); + std::vector sizes; + if (keepdim) { + sizes = self.sizes().vec(); + sizes[dim] = 1; + } + else { + for (const auto d : c10::irange(self.dim())) { + if (d != dim) { + sizes.push_back(self.sizes()[d]); + } + } + } + return sizes; +} + +// Resize the result tensor and indices when result.numel() == 0 depending on values of +// dim and keepdim for returning tensors containing reduction results. +// This function should be called when you are reducing a zero-numel tensor and want to +// resize the output and return it. This function exists for resizing zero-numel +// tensors when the size of the reduction dimension is non-zero. +[[maybe_unused]] inline void zero_numel_tensor_resize( + Tensor& result, + Tensor& result_indices, + const Tensor& self, + const int64_t dim, + const bool keepdim, + const char* fn_name) { + auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name); + at::native::resize_output(result, sizes); + at::native::resize_output(result_indices, sizes); +} + +inline ScalarType get_dtype_from_self( + const Tensor& self, + const std::optional& dtype, + bool promote_integers) { + if (dtype.has_value()) { + return dtype.value(); + } + ScalarType src_type = self.scalar_type(); + if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) { + return kLong; + } + return src_type; +} + +inline ScalarType get_dtype_from_result(Tensor& result, std::optional dtype) { + TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. 
You likely tried to call an operator with an out argument but the out argument was an undefined tensor."); + if (dtype.has_value()) { + return dtype.value(); + } else { + return result.scalar_type(); + } +} + + +} // namespace at::native + +namespace at::meta { + +[[maybe_unused]] inline DimVector get_reduction_shape( + const Tensor& self, + IntArrayRef dims, + bool keepdim, + bool allow_empty_dims = false) { + auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims); + return native::shape_from_dim_mask(self, mask, keepdim); +} + +inline void resize_reduction( + impl::MetaBase& meta, + const Tensor& self, + OptionalIntArrayRef opt_dims, + bool keepdim, + ScalarType out_dtype, + bool allow_empty_dims=false) { + DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim()); + maybe_wrap_dims(dims_, self.dim()); + auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims); + if (self.layout() == kStrided) { + meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype)); + } else if (shape.empty()) { + meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype).layout(kStrided)); + } else { + TORCH_CHECK(false, "resize_reduction: support for output with ", self.layout(), " layout is not implemented yet"); + } + namedinference::propagate_names_for_reduction( + meta.maybe_get_output(), self, dims_, keepdim); +} + +inline void resize_reduction_with_indices( + impl::MetaBase& meta, + const Tensor& self, + IntArrayRef dims, + bool keepdim, + ScalarType out_dtype) { + DimVector dims_(dims); + maybe_wrap_dims(dims_, self.dim()); + auto shape = get_reduction_shape(self, dims_, keepdim); + meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype)); + meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong)); + namedinference::propagate_names_for_reduction( + meta.maybe_get_output(0), self, dims_, keepdim); + namedinference::propagate_names_for_reduction( + meta.maybe_get_output(1), self, dims_, keepdim); +} + +inline TensorIterator make_reduction( + const Tensor& self, + const Tensor& result, + OptionalIntArrayRef opt_dims, + bool keepdim, + ScalarType in_dtype) { + int64_t ndim = self.dim(); + auto mask = at::native::make_dim_mask(opt_dims, ndim); + auto viewed_result = + at::native::review_reduce_result(result, ndim, mask, keepdim); + if (self.scalar_type() == in_dtype) { + return TensorIterator::reduce_op(viewed_result, self); + } + return TensorIterator::reduce_op(viewed_result, self.to(in_dtype)); +} + +inline TensorIterator make_reduction( + const Tensor& self, + const Tensor& result1, + const Tensor& result2, + IntArrayRef dims, + bool keepdim, + ScalarType dtype1, + ScalarType /*dtype2*/) { + int64_t ndim = self.dim(); + auto mask = at::native::make_dim_mask(dims, ndim); + auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim); + auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim); + // special case for type promotion in mixed precision, improves computational efficiency. + // We don't generalize this to common mismatched input/output types to avoid cross product + // of templated kernel launches. 
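+  // Concretely: for a CUDA kHalf input with dtype1 == kFloat, the iterator
+  // reads the half data directly and the kernel accumulates into the float
+  // results, avoiding an up-front float copy of `self` via self.to(dtype1).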
+ if (self.scalar_type() == dtype1 || + (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) { + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self); + } + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1)); +} + +[[maybe_unused]] inline TensorIterator make_reduction_from_out_ty( + const Tensor& self, + const Tensor& result, + OptionalIntArrayRef opt_dims, + bool keepdim, + ScalarType out_dtype) { + // special case for type promotion in mixed precision, improves computational + // efficiency. + // not generalize this to common mismatched input/output types to avoid cross + // product of templated kernel launches. + const bool gpu_lowp_to_f32 = + (self.is_cuda() && + (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && + out_dtype == kFloat); + auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype; + return make_reduction(self, result, opt_dims, keepdim, in_dtype); +} + +} // namespace at::meta diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ReductionType.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ReductionType.h new file mode 100644 index 0000000000000000000000000000000000000000..8c6d751d46b5acc1fd7b7694a916ca3e9b6ac43f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ReductionType.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +namespace at::native { + +enum class ReductionType {MAX, MEAN, MIN, SUM, PROD}; + +inline ReductionType get_reduction_enum(const std::string_view& reduce) { + if (reduce == "max" || reduce == "amax") { + return ReductionType::MAX; + } else if (reduce == "mean") { + return ReductionType::MEAN; + } else if (reduce == "min" || reduce == "amin") { + return ReductionType::MIN; + } else if (reduce == "sum") { + return ReductionType::SUM; + } else if (reduce == "prod") { + return ReductionType::PROD; + } else { + TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce); + } +} + +// used for `scatter_reduce`, old options for BC. 
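+// e.g. get_operator_enum("add", /*use_new_options=*/false) yields
+// ReductionType::SUM and "multiply" yields ReductionType::PROD; with
+// use_new_options=true the newer names ("sum", "prod", "amax", "amin", "mean")
+// are accepted via get_reduction_enum instead.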
+inline ReductionType get_operator_enum(const std::string_view reduce, bool use_new_options) { + if (use_new_options) { + return get_reduction_enum(reduce); + } else { + if (reduce == "add") { + return ReductionType::SUM; + } else if (reduce == "multiply") { + return ReductionType::PROD; + } else { + TORCH_CHECK(false, "reduce argument must be either add or multiply.") + } + } +} + +} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Repeat.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Repeat.h new file mode 100644 index 0000000000000000000000000000000000000000..4440d7ccea833e0417039dc1662f71b24c64129b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Repeat.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::native { + +template < + typename index_t, + void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)> +static inline Tensor repeat_interleave_common( + const Tensor& repeats, + std::optional output_size) { + TORCH_CHECK( + repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat"); + TORCH_CHECK( + repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt, + "repeats has to be Long or Int tensor"); + if (repeats.size(0) == 0) { + return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + Tensor repeats_ = repeats.contiguous(); + Tensor cumsum = repeats.cumsum(0); + int64_t total = 0; + if (output_size.has_value()) { + total = output_size.value(); + } else { + total = cumsum[-1].item(); + TORCH_CHECK( + (repeats >= 0).all().item(), "repeats can not be negative"); + } + + Tensor result = at::empty({total}, repeats.options()); + const index_t* repeat_ptr = repeats_.const_data_ptr(); + const int64_t* cumsum_ptr = cumsum.const_data_ptr(); + index_t* result_ptr = result.data_ptr(); + compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total); + return result; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Resize.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Resize.h new file mode 100644 index 0000000000000000000000000000000000000000..bbae3f19213cc4c95a407b705a629272a129ac83 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Resize.h @@ -0,0 +1,205 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include + + +namespace at::native { + +// TODO: make all operations that resize given outputs use this function +// for consistency and maintainability. +// Some operations like `cat` might not be able to make the use of +// resize_output directly. For more details to understand how it works in `cat`, +// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362 +// Resizes outputs +// Functions accepting output tensors, like with the "out" kwarg, should +// call this function to handle resizing their output tensor. +// Issues a warning if the output tensor has one or more elements and +// needs resizing +// NOTE: In the future the warning will become an error +// Returns a bool saying whether or not the resize actually happened or not +TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape); +// WARNING: Do NOT call this directly. If you are resizing an output and want +// to support dynamic shapes call at::resize__symint and resize_output_check_symint. 
+// For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272 +TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape); + +// Utility for resize_output +// Returns a bool saying resize should happen or not and +// raises a warning if resizing for one or more elements +TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape); +TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape); + +TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes); +TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes); +TORCH_API void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& size_bytes); + +inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) { + // It does not make sense to try to resize a storage + // to hold 0 elements, and this can break + // if storage_offset is positive but + // new_size is 0, so just bail in that case + // (same comment is in cuda/Resize.h) + if (self->numel() == 0) { + return; + } + + const Storage& storage = self->unsafe_storage(); + if (!storage) { + auto new_storage = c10::make_intrusive( + StorageImpl::use_byte_size_t(), + new_size_bytes, + c10::GetCPUAllocator(), + true); + self->set_storage_keep_dtype(std::move(new_storage)); + } else if (new_size_bytes > storage.nbytes()) { + resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes); + } +} + +TORCH_API TensorImpl* resize_impl_cpu_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride, + bool resize_storage = true); + +template +T maybe_convert_symint(c10::SymInt) = delete; + +template <> +inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; } + +template <> +inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); } + +template +inline void checkInBoundsForStorage( + ArrayRef size, + ArrayRef stride, + T storage_offset, + const caffe2::TypeMeta& data_type, + const Storage& new_storage) { + T storage_size_bytes, storage_size_plus_offset_bytes; + if (stride.data()) { + storage_size_bytes = + at::detail::computeStorageNbytes(size, stride, data_type.itemsize()); + storage_size_plus_offset_bytes = at::detail::computeStorageNbytes( + size, stride, data_type.itemsize(), storage_offset); + } else { + storage_size_bytes = + at::detail::computeStorageNbytesContiguous(size, data_type.itemsize()); + storage_size_plus_offset_bytes = at::detail::computeStorageNbytesContiguous( + size, data_type.itemsize(), storage_offset); + } + // It's ok to always evaluate to False for this early return for SymInts because + // (1) maybe_convert_symint below only installs guard for int64_t case + // (2) we check for this condition in the TORCH_MAYBE_SYM_CHECK below + if (TORCH_GUARD_OR_FALSE(sym_eq(storage_size_bytes, 0))) { + // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel. 
+ return; + } + T new_storage_size_bytes = maybe_convert_symint(new_storage.sym_nbytes()); + TORCH_MAYBE_SYM_CHECK( + sym_eq(storage_size_bytes, 0) || sym_le(storage_size_plus_offset_bytes, new_storage_size_bytes), + "setStorage: sizes ", + size, + ", strides ", + stride, + "," + " storage offset ", + storage_offset, + ", and itemsize ", + data_type.itemsize(), + " requiring a storage size of ", + storage_size_plus_offset_bytes, + " are out of bounds for storage of size ", + new_storage_size_bytes); +} + +template +inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, + ArrayRef size, ArrayRef stride, bool check_offset_in_bounds = true) { + // FIXME: stride should be optional + if (stride.data()) { + TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(), + ") and stride length (", stride.size(), ")"); + } + +#ifdef DEBUG + TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX"); +#endif + + // storageOffset + TORCH_CHECK( + TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0)), "Tensor: invalid storage offset ", storage_offset); + + // set_storage_{device} (except set_storage_meta__symint) + // will (unsafely) set the storage offset and then call resize_impl that + // handles resizing the storage However, resize_impl will only resize the + // storage if the sizes/strides changed. For the case that the sizes/strides + // remain unchanged, the storage offset is not properly validated, so we do + // that here. + if (check_offset_in_bounds) { + auto result_tensor_impl = result.unsafeGetTensorImpl(); + bool size_unchanged = result_tensor_impl->generic_sizes() == size; + bool stride_unchanged = stride.data() + ? result_tensor_impl->generic_strides() == stride + : true; + if (size_unchanged && stride_unchanged) { + checkInBoundsForStorage( + size, stride, storage_offset, result.dtype(), storage); + } + } + + // storage: note this can't be replaced with result.set_(storage) as the semantics of that + // function is to set the tensor size to be equal to the size of the storage. + if (!result.storage().is_alias_of(storage)) { + // Caffe2 might have tensors whose storages are null, but we + // don't allow it in PyTorch. + TORCH_INTERNAL_ASSERT(storage); + TORCH_INTERNAL_ASSERT(result.storage()); + + // We used to allow this, but this breaks device caching. + // Let's put an actual error message for this one. + TORCH_CHECK(result.storage().device() == storage.device(), + "Attempted to set the storage of a tensor on device \"", result.storage().device(), + "\" to a storage on different device \"", storage.device(), + "\". This is no longer allowed; the devices must match."); + result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage)); + } +} + +/** + * Set self's sizes, strides, and storage_offset. + * (size, stride, storage_offset) must be in bounds for self's storage. 
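+ * For example, size [2, 3] with stride [3, 1] and storage_offset 4 requires
+ * the storage to hold at least 4 + (2 - 1) * 3 + (3 - 1) * 1 + 1 = 10 elements.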
+ */ +template +inline void setStrided( + const Tensor& self, + ArrayRef size, + ArrayRef stride, + T storage_offset) { + TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape"); + for (const auto& val : stride) { + TORCH_CHECK(val >= 0, + "as_strided: Negative strides are not supported at the moment, " + "got strides: ", stride); + } + + auto* self_ = self.unsafeGetTensorImpl(); + checkInBoundsForStorage( + size, stride, storage_offset, self_->dtype(), self_->storage()); + + /* storage offset */ + TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset); + self_->set_sizes_and_strides(size, stride, storage_offset); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ResizeCommon.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ResizeCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..30eeb5274f0eeb033aeccd4215174d0b0a790bb3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ResizeCommon.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { + +template +inline T storage_size_for(ArrayRef size, ArrayRef stride) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(), + "storage_size_for(size, stride) requires that size and stride ", + "have the same size as a precondition."); + T storage_size = 1; + for (const auto dim : c10::irange(size.size())) { + if (size[dim] == 0) { + storage_size = 0; + break; + } + storage_size += (size[dim] - 1) * stride[dim]; + } + return storage_size; +} + +inline const Tensor& resize_named_tensor_( + const Tensor& self, + IntArrayRef size, + std::optional optional_memory_format) { + TORCH_INTERNAL_ASSERT(self.has_names()); + TORCH_CHECK( + self.sizes() == size, + "Cannot resize named tensor with resize_ or resize_as_ (tried to resize " + "Tensor", + self.names(), + " with size ", + self.sizes(), + " to ", + size, + "). This may be caused by passing a named tensor ", + "as an `out=` argument; please ensure that the sizes are the same. "); + TORCH_CHECK( + !optional_memory_format.has_value(), + "Unsupported memory format for named tensor resize ", + optional_memory_format.value()); + return self; +} + +// For deterministic output, fill new elements that were added after a storage +// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage +// before the resize happened. 
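+// In practice this builds a 1-D view over only the newly added region of the
+// storage (offset old_storage_numel, length new_storage_numel -
+// old_storage_numel) and runs fill_empty_deterministic_ on that view, leaving
+// the pre-existing elements untouched.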
+inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) { + const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage(); + int64_t new_storage_nbytes = storage.nbytes(); + int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize(); + int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize(); + if (new_storage_numel > old_storage_numel) { + at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device())); + tensor_view.set_( + storage, + /*storage_offset=*/old_storage_numel, + /*size=*/{new_storage_numel - old_storage_numel}, + /*stride=*/{1}); + at::native::fill_empty_deterministic_(tensor_view); + } + return tensor; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ScatterGatherChecks.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ScatterGatherChecks.h new file mode 100644 index 0000000000000000000000000000000000000000..1a978db322591b5a7e03f42906d84df9326c7422 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ScatterGatherChecks.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +namespace { + +// checks whether index.dtype == int64 +// and self.dtype == src.dtype if src is a Tensor +inline void scatter_gather_dtype_check( + const std::string& method_name, + const Tensor& self, + const Tensor& index, + const std::optional& src_opt = std::nullopt +) { + if (index.numel() != 0) { + TORCH_CHECK( + index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int, + method_name, "(): Expected dtype int32/int64 for index" + ); + } + + if (src_opt.has_value()) { + const auto& src = src_opt.value(); + TORCH_CHECK( + self.scalar_type() == src.scalar_type(), + method_name, "(): Expected self.dtype to be equal to src.dtype" + ); + } +} + +// Used for `gather`-like methods +// Note: self means the input tensor here +// Test: +// 1. index.size(d) <= self.size(d) for all d != dim +// 2. index.dim() == self.dim() +inline void gather_shape_check(const Tensor& self, int64_t dim, + const Tensor& index +) { + auto self_dims = ensure_nonempty_dim(self.dim()); + TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()), + "Index tensor must have the same number of dimensions as input tensor" + ); + + for (const auto i : c10::irange(self_dims)) { + if (i != dim) { + TORCH_CHECK( + ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), + "Size does not match at dimension ", i, + " expected index ", index.sizes(), + " to be no larger than self ", self.sizes(), + " apart from dimension ", dim + ); + } + } +} + +// Used for `scatter` and `scatter_add` +// Tests: +// 1. index.size(d) <= self.size(d) for all d != dim +// 2. index.size(d) <= src.size(d) for all d if src is a Tensor +// 3. 
index.dim() == self.dim() == src.dim() +inline void scatter_shape_check( + const Tensor& self, int64_t dim, const Tensor& index, + const std::optional& src_opt = std::nullopt +) { + if (index.numel() == 0) return; + TORCH_CHECK( + ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), + "Index tensor must have the same number of dimensions as self tensor" + ); + + bool is_wrong_shape = false; + int64_t self_dims = ensure_nonempty_dim(self.dim()); + + // Check: index.size(d) <= self.size(d) for all d != dim + for (const auto d : c10::irange(self_dims)) { + int64_t index_d_size = ensure_nonempty_size(index, d); + if (d == dim) continue; + if (index_d_size > ensure_nonempty_size(self, d)) { + is_wrong_shape = true; + break; + } + } + + // Check: index.size(d) <= src.size(d) for all d if src is Tensor + if (!is_wrong_shape && src_opt.has_value()) { + const auto& src = src_opt.value(); + for (const auto d : c10::irange(self_dims)) { + int64_t index_d_size = ensure_nonempty_size(index, d); + if (index_d_size > ensure_nonempty_size(src, d)) { + is_wrong_shape = true; + break; + } + } + } + + if (src_opt.has_value()) { + const auto& src = src_opt.value(); + + TORCH_CHECK( + ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()), + "Index tensor must have the same number of dimensions as src tensor" + ); + + TORCH_CHECK(!is_wrong_shape, + "Expected index ", index.sizes(), + " to be no larger than self ", self.sizes(), + " apart from dimension ", dim, + " and to be no larger size than src ", src.sizes() + ); + } + else { + TORCH_CHECK(!is_wrong_shape, + "Expected index ", index.sizes(), + " to be no larger than self ", self.sizes(), + " apart from dimension ", dim + ); + } +} + +} // anonymous namespace + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/SegmentReduce.h b/phivenv/Lib/site-packages/torch/include/ATen/native/SegmentReduce.h new file mode 100644 index 0000000000000000000000000000000000000000..421fe085d290d853cf8ccbdb5e59e97f5033381a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/SegmentReduce.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using segment_reduce_lengths_fn = Tensor (*)( + ReductionType, + const Tensor&, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub) + +using segment_reduce_offsets_fn = Tensor (*)( + ReductionType, + const Tensor&, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub) + +using segment_reduce_lengths_backward_fn = Tensor (*)( + const Tensor&, + const Tensor&, + const Tensor&, + ReductionType, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub) + +using segment_reduce_offsets_backward_fn = Tensor (*)( + const Tensor&, + const Tensor&, + const Tensor&, + ReductionType, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub) + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/SharedReduceOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/SharedReduceOps.h new file mode 100644 index 0000000000000000000000000000000000000000..1815ccf8cb20f11248cb1f97feed215e84ce2c04 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/SharedReduceOps.h @@ -0,0 +1,545 @@ +#pragma once +// Please note that this file is +// used across both CPU and GPU. + +#include +#include +#include +#include +#include +#include +#if defined(__CUDACC__) +#include +#include +#elif defined(__HIPCC__) +#include +#include +#endif +#if defined(__CUDACC__) || defined(__HIPCC__) +#include +#else +#include +#define device_sqrt std::sqrt +#endif +#if defined(__CUDACC__) || defined(__HIPCC__) +template +inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) { +#if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b)); +#else + scalar_t max = at::_isnan(b) ? b : std::max(a, b); +#endif + return max; +} +template +inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) { +#if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b)); +#else + scalar_t min = at::_isnan(b) ? b : std::min(a, b); +#endif + return min; +} +#define MAX(X, Y) max_propagate_nan(X,Y) +#define MIN(X, Y) min_propagate_nan(X,Y) +#else +#include +#define MAX(X, Y) max_impl(X,Y) +#define MIN(X, Y) min_impl(X,Y) +#endif + +// ROCM hcc doesn't work well with using std:: in kernel functions +#if defined(__CUDA_ARCH__) +#include +#define compat_pow c10::cuda::compat::pow +#elif defined(__HIPCC__) +#include +#define compat_pow c10::hip::compat::pow +#else +#define compat_pow std::pow +#endif + +namespace at::native { + +namespace detail { + +#if defined(__CUDACC__) || defined(__HIPCC__) +template using pair = thrust::pair; +#else +template using pair = std::pair; +#endif + +} // namespace detail + +template +struct WelfordData { + scalar_t mean; + scalar_t m2; + index_t n; + scalar_t nf; + + C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {} + + C10_HOST_DEVICE WelfordData( + scalar_t mean, + scalar_t m2, + index_t n, + scalar_t nf) + : mean(mean), m2(m2), n(n), nf(nf) {} +}; + + +template +struct WelfordOps { + acc_scalar_t correction; + bool take_sqrt; + public: + using acc_t = WelfordData; + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const { + // We accumulate n in index_t to avoid cumulative rounding error, but still + // need nf for use in combine where int32 may overflow. + index_t new_n = acc.n + 1; + acc_scalar_t new_nf = static_cast(new_n); + acc_scalar_t delta = data - acc.mean; + acc_scalar_t new_mean = acc.mean + delta / new_nf; + acc_scalar_t new_delta = data - new_mean; + return { + new_mean, + acc.m2 + delta * new_delta, + new_n, + new_nf, + }; + } + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + if (a.nf == 0) { + return b; + } + if (b.nf == 0) { + return a; + } + acc_scalar_t delta = b.mean - a.mean; + acc_scalar_t new_count = a.nf + b.nf; + acc_scalar_t nb_over_n = b.nf / new_count; + return { + a.mean + delta * nb_over_n, + a.m2 + b.m2 + delta * delta * a.nf * nb_over_n, + // setting acc.n as -1 since acc.n might not be able to represent the count + // correctly within its range, setting it to -1 to avoid confusion + -1, + new_count + }; + } + inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ { + const auto mean = static_cast(acc.mean); + const auto divisor = acc.nf > correction ? 
acc.nf - correction : 0; + const auto var = acc.m2 / divisor; + res_t results(take_sqrt ? device_sqrt(var) : var, mean); + return results; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const { + return { + WARP_SHFL_DOWN(acc.mean, offset) + , WARP_SHFL_DOWN(acc.m2, offset) + , WARP_SHFL_DOWN(acc.n, offset) + , WARP_SHFL_DOWN(acc.nf, offset) + }; + } +#endif + C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt) + : correction(correction), take_sqrt(take_sqrt) {} +}; + +template +struct MeanOps { + factor_t factor; + + inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const { + return combine(a, static_cast(b)); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a * factor; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif + + MeanOps(factor_t factor): factor(factor) { + } +}; + +// This accumulator template is used to calculate the minimum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct AbsMinOps { + + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MIN(acc, static_cast(std::abs(at::opmath_type(data)))); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return MIN(a, b); + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +// This accumulator template is used to calculate the maximum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct AbsMaxOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MAX(acc, static_cast(std::abs(at::opmath_type(data)))); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return MAX(a, b); + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +// This accumulator template is used to calculate the norm of the absolute value +// of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. 
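+// Worked example for norm_ == 2 over the values {3, 4}: reduce() accumulates
+// |3|^2 + |4|^2 = 25 and project() returns 25^(1/2) = 5.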
+template +struct NormOps { + acc_t norm_; + + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + compat_pow(static_cast(std::abs(at::opmath_type(data))), norm_); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return compat_pow(a, static_cast(1.0) / norm_); + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif + + NormOps(acc_t norm_): norm_(norm_) { + } +}; + +// This accumulator template is used to calculate the order zero norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct NormZeroOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + (data == static_cast(0) ? static_cast(0) : static_cast(1)); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +// This accumulator template is used to calculate the order one norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct NormOneOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + static_cast(std::abs(at::opmath_type(data))); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + + +template +struct AbsSwitch {}; + +template +inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch) { + return static_cast(data); +} + +template +inline C10_DEVICE acc_t abs_if_complex(std::complex data, AbsSwitch) { + return static_cast(std::abs(data)); +} + +template +inline C10_DEVICE acc_t abs_if_complex(c10::complex data, AbsSwitch) { + return static_cast(std::abs(at::opmath_type>(data))); +} + +// This accumulator template is used to calculate the order two norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. 
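+// Unlike the generic NormOps above, this specialization squares the
+// (complex-aware) absolute value directly and takes a single device_sqrt in
+// project(), so no compat_pow call is needed on the hot path.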
+template +struct NormTwoOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + acc_t data_ = abs_if_complex(data, AbsSwitch()); + return acc + data_ * data_; + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return device_sqrt(a); + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +template +struct NanSumOps { + inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const { + return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b}); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE data_t project(acc_t a) const { + return data_t{a}; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif +}; + +namespace detail { + +template +struct LessOrNan { + C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else min(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a < b); + } +}; + +template +struct GreaterOrNan { + C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else max(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a > b); + } +}; + +template +struct MinMaxReductionOps { + using scalar_t = typename binary_function_traits::arg1_t; + using index_t = int64_t; + using arg_t = detail::pair; + + static C10_DEVICE arg_t project(arg_t arg) { + return arg; + } + + static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) { + return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx); + } + + static C10_DEVICE arg_t combine(arg_t a, arg_t b) { + return comp_t{}(a.first, b.first, a.second, b.second) ? 
a : b; + } + + static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) { + return {a.first, a.second + base_idx}; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) { + return arg_t(WARP_SHFL_DOWN(arg.first, offset), + WARP_SHFL_DOWN(arg.second, offset)); + } +#endif +}; + +template +struct ArgReductionOps : public MinMaxReductionOps { + using typename MinMaxReductionOps::scalar_t; + using typename MinMaxReductionOps::index_t; + using typename MinMaxReductionOps::arg_t; + + static C10_DEVICE index_t project(arg_t arg) { + return arg.second; + } +}; + +} // namespace detail + +template +struct ArgMaxOps : + public detail::ArgReductionOps> { +}; + +template +struct ArgMinOps : + public detail::ArgReductionOps> { +}; + +template +struct MinOps : + public detail::MinMaxReductionOps> { +}; + +template +struct MaxOps : + public detail::MinMaxReductionOps> { +}; + +template +struct MinMaxOps { + using acc_t = detail::pair; + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const { + return combine(acc, {data, data}); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first; + auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second; + + return {min_val, max_val}; + } + + inline C10_DEVICE acc_t project(acc_t acc) const { + return acc; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return { + WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset) + }; + } +#endif +}; + +} // namespace at::native + +#undef MAX +#undef MIN diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..3fbc4e311dea531b7d0a2501dad0685671f8a1b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h @@ -0,0 +1,55 @@ +/// This file contains some tensor-agnostic operations to be used in the +/// core functions of the `SobolEngine` +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +namespace at::native::sobol_utils { + +/// Function to return the minimum of number of bits to represent the integer `n` +inline int64_t bit_length(const int64_t n) { + int64_t nbits, nloc; + for (nloc = n, nbits = 0; nloc > 0; nloc /= 2, nbits++); + return nbits; +} + +/// Function to get the position of the rightmost zero in the bit representation of an integer +/// This value is the zero-indexed position +inline int64_t rightmost_zero(const int64_t n) { + int64_t z, i; + for (z = n, i = 0; z % 2 == 1; z /= 2, i++); + return i; +} + +/// Function to get a subsequence of bits in the representation of an integer starting from +/// `pos` and of length `length` +inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) { + return (n >> pos) & ((1 << length) - 1); +} + +/// Function to perform the inner product between a batched square matrix and a power of 2 vector +inline at::Tensor cdot_pow2(const at::Tensor& bmat) { + at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options()); + inter = at::pow(2, inter).expand_as(bmat); + return at::mul(inter, 
bmat).sum(-1); +} + +/// All definitions below this point are data. These are constant, and should not be modified +/// without notice + +constexpr int64_t MAXDIM = 21201; +constexpr int64_t MAXDEG = 18; +constexpr int64_t MAXBIT = 30; +constexpr int64_t LARGEST_NUMBER = 1 << MAXBIT; +constexpr float RECIPD = 1.0 / LARGEST_NUMBER; + +extern const int64_t poly[MAXDIM]; +extern const int64_t initsobolstate[MAXDIM][MAXDEG]; + +} // namespace at::native::sobol_utils diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Sorting.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..938f6603b9c03e8fdc144aa4fb11152e562ca8f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Sorting.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +enum class QUANTILE_INTERPOLATION_MODE : uint8_t { + LINEAR, + LOWER, + HIGHER, + MIDPOINT, + NEAREST +}; + +using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool); +using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool); + +DECLARE_DISPATCH(sort_fn, sort_stub) +DECLARE_DISPATCH(topk_fn, topk_stub) + +void _fill_indices(const TensorBase &indices, int64_t dim); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/SortingUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/SortingUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..cb9e3d37c6768e08cc091d2ce8c7efed04a8a2cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/SortingUtils.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { + +// ensure we get good values and indices for kthvalue, mode +// this will always be with the reducing dim as 1-d +inline void _reduction_with_indices_allocate_or_resize_output( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true); + auto result_sizes = self.sizes().vec(); + if (!result_sizes.empty()) { + result_sizes[dim] = 1; + } + if (values.defined()) { + TORCH_CHECK( + self.options().type_equal(values.options()), + "output values must be of same type as input"); + if (!keepdim && values.dim() == self.dim() - 1) { + // unsqueeze to preserve passed in noncontiguous tensor in resize + values.unsqueeze_(dim); + } + resize_output(values, result_sizes); + } else { + values = at::empty(result_sizes, self.options()); + } + if (indices.defined()) { + TORCH_CHECK( + indices.dtype() == kLong, "output indices must be of scalar type Long"); + TORCH_CHECK( + indices.device() == self.device(), + "output indices must be on same device as input"); + if (!keepdim && indices.dim() == self.dim() - 1) { + // unsqueeze to preserve passed in noncontiguous tensor in resize + indices.unsqueeze_(dim); + } + resize_output(indices, result_sizes); + } else { + indices = at::empty(result_sizes, self.options().dtype(kLong)); + } +} + +// ensure we get good values and indices for topk +inline void _allocate_or_resize_output_with_indices( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim_, + int64_t k) { + int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true); + auto result_sizes = 
self.sizes().vec(); + if (!result_sizes.empty()) { + result_sizes[dim] = k; + } + if (values.defined()) { + TORCH_CHECK( + self.options().type_equal(values.options()), + "output values must be of same type as input"); + values.resize_(result_sizes); + } else { + values = at::empty(result_sizes, self.options()); + } + if (indices.defined()) { + TORCH_CHECK( + indices.dtype() == kLong, "output indices must be of scalar type Long"); + TORCH_CHECK( + indices.device() == self.device(), + "output indices must be on same device as input"); + indices.resize_(result_sizes); + } else { + indices = at::empty(result_sizes, self.options().dtype(kLong)); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/SparseTensorUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/SparseTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..f44d51b352eebed86c0743eca4446842a1b65ca3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/SparseTensorUtils.h @@ -0,0 +1,190 @@ +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::sparse { + +// Just for documentary purposes +using SparseTensor = Tensor; +using SparseType = Type; + +// This is an internal utility function for getting at the SparseTensorImpl, +// so that we can write sparse tensor specific accessors for special fields +// in SparseTensor. You should only use this for writing low level +// setters/getters for SparseTensorImpl fields; otherwise, you should use +// the low level setters/getters that were implemented using this. +// +// This may be called repeatedly, so make sure it's pretty cheap. +inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) { + TORCH_INTERNAL_ASSERT( + self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor"); + return static_cast(self.unsafeGetTensorImpl()); +} + +// Takes indices and values and directly puts them into the sparse tensor, no +// copy. This used to be called THSTensor_(_move) +inline void alias_into_sparse( + const SparseTensor& self, + const Tensor& indices, + const Tensor& values) { + get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values); +} + +// Take indices and values and makes a (data) copy of them to put into the +// sparse indices/values. This used to be called THSTensor_(_set) +inline void copy_into_sparse( + const SparseTensor& self, + const Tensor& indices, + const Tensor& values, + bool non_blocking) { + alias_into_sparse( + self, + indices.to(self._indices().options(), non_blocking, /*copy=*/true), + values.to(self._values().options(), non_blocking, /*copy=*/true)); +} + +// TODO: put this into the public API +inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) { + return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl(); +} + +inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) { + return self.sparse_dim() == src.sparse_dim() && + self.dense_dim() == src.dense_dim(); +} + +// Give us a new values tensor, with the same dimensionality +// as 'values' but with a new number of non-zero elements. +// TODO: Expose this for real in ATen, some day? +// NB: Doesn't preserve data. 
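+// e.g. values of shape [old_nnz, 4, 5] with nnz == 7 yields an uninitialized
+// tensor of shape [7, 4, 5] with the same dtype/device options as `values`.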
+inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) { + std::vector size = values.sizes().vec(); + size[0] = nnz; + return at::empty(size, values.options()); +} + +// NOTE [ Flatten Sparse Indices ] +// This helper function flattens a sparse indices tensor (a Tensor) into a 1D +// indices tensor. E.g., +// input = [[2, 4, 0], +// [3, 1, 10]] +// full_size = [2, 12] +// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10] +// +// In other words, assuming that each `indices[i, :]` is a valid index to a +// tensor `t` of shape `full_size`. This returns the corresponding indices to +// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`. +// if forceClone is true, the result will forced to be a clone of self. +// if force_clone is true, the result will forced to be a clone of self. +TORCH_API Tensor flatten_indices( + const Tensor& indices, + IntArrayRef full_size, + bool force_clone = false); + +// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten +// Sparse Indices ], except this one allows partial flatten: only flatten on +// specified dims. Note that the flatten indices might be uncoalesced if +// dims_to_flatten.size() < sparse_dim. Also if input indices is already +// coalesced, the flattened indices will also be sorted. +// +// args: +// indices: sparse tensor indices +// sizes: sparse tensor sizes +// dims_to_flatten: a list of dim index to flatten +// +// Ex1: +// indices = [[2, 4, 0], +// [3, 1, 3]] +// sizes = [2, 12] +// dims_to_flatten = [0, 1] +// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3] +// +// Ex2: +// dims_to_flatten = [1] +// new_indices = [ 3, 1, 3 ] # uncoalesced +TORCH_API Tensor flatten_indices_by_dims( + const Tensor& indices, + const IntArrayRef& sizes, + const IntArrayRef& dims_to_flatten); + +// Find the CSR representation for a row `indices` from the COO format +TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz); + +TORCH_API Tensor zeros_like_with_indices(const Tensor& t); + +template +class TensorGeometryHolder { + using geometry_holder_t = std::array; + + public: + explicit TensorGeometryHolder( + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options = {}) { + std::copy(sizes.begin(), sizes.end(), t_sizes.begin()); + std::copy(strides.begin(), strides.end(), t_strides.begin()); + } + + explicit TensorGeometryHolder(const Tensor& t) + : TensorGeometryHolder(t.sizes(), t.strides()) {} + + auto operator*() const { + return std::make_tuple(t_sizes, t_strides); + } + + private: + geometry_holder_t t_sizes; + geometry_holder_t t_strides; +}; + +template <> +class TensorGeometryHolder<0> { + using geometry_holder_t = Tensor; + + public: + explicit TensorGeometryHolder( + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options) { + const int64_t t_ndims = sizes.size(); + const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU); + Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options); + t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options)); + t_sizes_and_strides_cpu.select(0, 1).copy_( + at::tensor(strides, cpu_options)); + const Tensor t_sizes_and_strides = + t_sizes_and_strides_cpu.to(options.device()); + t_sizes = t_sizes_and_strides.select(0, 0); + t_strides = t_sizes_and_strides.select(0, 1); + } + + explicit TensorGeometryHolder(const Tensor& t) + : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {} + + auto operator*() const { + return std::make_tuple( 
+ t_sizes.template data_ptr(), + t_strides.template data_ptr()); + } + + private: + geometry_holder_t t_sizes; + geometry_holder_t t_strides; +}; + +// Return all indices of a tensor with the given shape. +// +// full_coo_indices(shape) is equivalent to +// torch.ones(shape).nonzero().transpose(-2, -1) but much faster. +TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options); + +} // namespace at::sparse diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/SpectralOpsUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/SpectralOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..3273cbe5a34be9a71fb0c8b082c85847df0ca427 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/SpectralOpsUtils.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// Normalization types used in _fft_with_size +enum class fft_norm_mode { + none, // No normalization + by_root_n, // Divide by sqrt(signal_size) + by_n, // Divide by signal_size +}; + +// NOTE [ Fourier Transform Conjugate Symmetry ] +// +// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is, +// assuming X is the transformed K-dimensionsal signal, we have +// +// X[i_1, ..., i_K] = X[j_i, ..., j_K]*, +// +// where j_k = (N_k - i_k) mod N_k, N_k being the signal size at dim k, +// * is the conjugate operator. +// +// Therefore, in such cases, FFT libraries return only roughly half of the +// values to avoid redundancy: +// +// X[:, :, ..., :floor(N / 2) + 1] +// +// This is also the assumption in cuFFT and MKL. In ATen SpectralOps, such +// halved signal will also be returned by default (flag onesided=True). +// The following infer_ft_real_to_complex_onesided_size function calculates the +// onesided size from the twosided size. +// +// Note that this loses some information about the size of signal at last +// dimension. E.g., both 11 and 10 maps to 6. Hence, the following +// infer_ft_complex_to_real_onesided_size function takes in optional parameter +// to infer the twosided size from given onesided size. +// +// cuFFT doc: http://docs.nvidia.com/cuda/cufft/index.html#multi-dimensional +// MKL doc: https://software.intel.com/en-us/mkl-developer-reference-c-dfti-complex-storage-dfti-real-storage-dfti-conjugate-even-storage#CONJUGATE_EVEN_STORAGE + +inline int64_t infer_ft_real_to_complex_onesided_size(int64_t real_size) { + return (real_size / 2) + 1; +} + +inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size, + int64_t expected_size=-1) { + int64_t base = (complex_size - 1) * 2; + if (expected_size < 0) { + return base + 1; + } else if (base == expected_size) { + return base; + } else if (base + 1 == expected_size) { + return base + 1; + } else { + std::ostringstream ss; + ss << "expected real signal size " << expected_size << " is incompatible " + << "with onesided complex frequency size " << complex_size; + TORCH_CHECK(false, ss.str()); + } +} + +using fft_fill_with_conjugate_symmetry_fn = + void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes, + IntArrayRef in_strides, const void* in_data, + IntArrayRef out_strides, void* out_data); +DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub) + +// In real-to-complex transform, cuFFT and MKL only fill half of the values +// due to conjugate symmetry. 
This function fills in the other half of the full +// fft by using the Hermitian symmetry in the signal. +// self should be the shape of the full signal and dims.back() should be the +// one-sided dimension. +// See NOTE [ Fourier Transform Conjugate Symmetry ] +TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/StridedRandomAccessor.h b/phivenv/Lib/site-packages/torch/include/ATen/native/StridedRandomAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..5ee7da926bae6bdf5c8c9e1149152f6c75f263aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/StridedRandomAccessor.h @@ -0,0 +1,301 @@ +#pragma once + +namespace at::native { + +// (Const)StridedRandomAccessor is a +// (const) random access iterator defined over +// a strided array. + +// The traits below are to introduce __restrict__ +// modifier on different platforms. + +template +struct DefaultPtrTraits { + using PtrType = T*; +}; + +#if (defined(_WIN32) || defined(_WIN64)) +#define RESTRICT __restrict +#else +#define RESTRICT __restrict__ +#endif + +template +struct RestrictPtrTraits { + using PtrType = T* RESTRICT; +}; + +template < + typename T, + typename index_t = int64_t, + template class PtrTraits = DefaultPtrTraits +> +class ConstStridedRandomAccessor { +public: + using difference_type = index_t; + using value_type = const T; + using pointer = const typename PtrTraits::PtrType; + using reference = const value_type&; + using iterator_category = std::random_access_iterator_tag; + + using PtrType = typename PtrTraits::PtrType; + using index_type = index_t; + + // Constructors { + C10_HOST_DEVICE + ConstStridedRandomAccessor(PtrType ptr, index_t stride) + : ptr{ptr}, stride{stride} + {} + + C10_HOST_DEVICE + explicit ConstStridedRandomAccessor(PtrType ptr) + : ptr{ptr}, stride{static_cast(1)} + {} + + C10_HOST_DEVICE + ConstStridedRandomAccessor() + : ptr{nullptr}, stride{static_cast(1)} + {} + // } + + // Pointer-like operations { + C10_HOST_DEVICE + reference operator*() const { + return *ptr; + } + + C10_HOST_DEVICE + const value_type* operator->() const { + return reinterpret_cast(ptr); + } + + C10_HOST_DEVICE + reference operator[](index_t idx) const { + return ptr[idx * stride]; + } + // } + + // Prefix/postfix increment/decrement { + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator++() { + ptr += stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor operator++(int) { + ConstStridedRandomAccessor copy(*this); + ++*this; + return copy; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator--() { + ptr -= stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor operator--(int) { + ConstStridedRandomAccessor copy(*this); + --*this; + return copy; + } + // } + + // Arithmetic operations { + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator+=(index_t offset) { + ptr += offset * stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor operator+(index_t offset) const { + return ConstStridedRandomAccessor(ptr + offset * stride, stride); + } + + C10_HOST_DEVICE + friend ConstStridedRandomAccessor operator+( + index_t offset, + const ConstStridedRandomAccessor& accessor + ) { + return accessor + offset; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator-=(index_t offset) { + ptr -= offset * stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor 
operator-(index_t offset) const { + return ConstStridedRandomAccessor(ptr - offset * stride, stride); + } + + // Note that this operator is well-defined when `this` and `other` + // represent the same sequences, i.e. when + // 1. this.stride == other.stride, + // 2. |other - this| / this.stride is an Integer. + C10_HOST_DEVICE + difference_type operator-(const ConstStridedRandomAccessor& other) const { + return (ptr - other.ptr) / stride; + } + // } + + // Comparison operators { + C10_HOST_DEVICE + bool operator==(const ConstStridedRandomAccessor& other) const { + return (ptr == other.ptr) && (stride == other.stride); + } + + C10_HOST_DEVICE + bool operator!=(const ConstStridedRandomAccessor& other) const { + return !(*this == other); + } + + C10_HOST_DEVICE + bool operator<(const ConstStridedRandomAccessor& other) const { + return ptr < other.ptr; + } + + C10_HOST_DEVICE + bool operator<=(const ConstStridedRandomAccessor& other) const { + return (*this < other) || (*this == other); + } + + C10_HOST_DEVICE + bool operator>(const ConstStridedRandomAccessor& other) const { + return !(*this <= other); + } + + C10_HOST_DEVICE + bool operator>=(const ConstStridedRandomAccessor& other) const { + return !(*this < other); + } + // } + +protected: + PtrType ptr; + index_t stride; +}; + +template < + typename T, + typename index_t = int64_t, + template class PtrTraits = DefaultPtrTraits +> +class StridedRandomAccessor + : public ConstStridedRandomAccessor { +public: + using difference_type = index_t; + using value_type = T; + using pointer = typename PtrTraits::PtrType; + using reference = value_type&; + + using BaseType = ConstStridedRandomAccessor; + using PtrType = typename PtrTraits::PtrType; + + // Constructors { + C10_HOST_DEVICE + StridedRandomAccessor(PtrType ptr, index_t stride) + : BaseType(ptr, stride) + {} + + C10_HOST_DEVICE + explicit StridedRandomAccessor(PtrType ptr) + : BaseType(ptr) + {} + + C10_HOST_DEVICE + StridedRandomAccessor() + : BaseType() + {} + // } + + // Pointer-like operations { + C10_HOST_DEVICE + reference operator*() const { + return *this->ptr; + } + + C10_HOST_DEVICE + value_type* operator->() const { + return reinterpret_cast(this->ptr); + } + + C10_HOST_DEVICE + reference operator[](index_t idx) const { + return this->ptr[idx * this->stride]; + } + // } + + // Prefix/postfix increment/decrement { + C10_HOST_DEVICE + StridedRandomAccessor& operator++() { + this->ptr += this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator++(int) { + StridedRandomAccessor copy(*this); + ++*this; + return copy; + } + + C10_HOST_DEVICE + StridedRandomAccessor& operator--() { + this->ptr -= this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator--(int) { + StridedRandomAccessor copy(*this); + --*this; + return copy; + } + // } + + // Arithmetic operations { + C10_HOST_DEVICE + StridedRandomAccessor& operator+=(index_t offset) { + this->ptr += offset * this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator+(index_t offset) const { + return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride); + } + + C10_HOST_DEVICE + friend StridedRandomAccessor operator+( + index_t offset, + const StridedRandomAccessor& accessor + ) { + return accessor + offset; + } + + C10_HOST_DEVICE + StridedRandomAccessor& operator-=(index_t offset) { + this->ptr -= offset * this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator-(index_t offset) const { + return 
StridedRandomAccessor(this->ptr - offset * this->stride, this->stride); + } + + // Note that here we call BaseType::operator- version + C10_HOST_DEVICE + difference_type operator-(const BaseType& other) const { + return (static_cast(*this) - other); + } + // } +}; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h new file mode 100644 index 0000000000000000000000000000000000000000..611823a6ac1d16f611102cdd1cb661af7141cf0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h @@ -0,0 +1,102 @@ +#pragma once + +// Indexing tensors by tensors + +#include +#include +#include +#include + +namespace at { +struct TensorIterator; +} + +namespace at::native { + +using index_put_with_sort_fn = void (*)( + Tensor&, + const c10::List>&, + const Tensor&, + bool accumulate, + bool unsafe); +using index_put_with_sort_quantized_fn = void (*)( + Tensor& self, + const c10::List>& indices, + const Tensor& value, + double scale, + int zero_point, + bool unsafe); +using gather_fn = void (*)( + const Tensor& result, + const Tensor& self, + int64_t dim, + const Tensor& index); +using scatter_fn = void (*)( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& src); +using scatter_fill_fn = void (*)( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Scalar& src); +using scatter_add_fn = void (*)( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& src); +using scatter_reduce_fn = void (*)( + const Tensor& self, + const int64_t dim, + const Tensor& index, + const Tensor& src, + const ReductionType& reduce); +using scatter_scalar_reduce_fn = void (*)( + const Tensor& self, + const int64_t dim, + const Tensor& index, + const Scalar& value, + const ReductionType& reduce); +using scatter_reduce_two_fn = void (*)( + const Tensor& self, + const int64_t dim, + const Tensor& index, + const Tensor& src, + const ReductionType& reduce); + +DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub) +DECLARE_DISPATCH( + index_put_with_sort_quantized_fn, + index_put_with_sort_quantized_stub) +DECLARE_DISPATCH(gather_fn, gather_stub) +DECLARE_DISPATCH(scatter_fn, scatter_stub) +DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub) +DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub) +DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub) +DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub) +DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub) + +TORCH_API Tensor& index_out( + Tensor& result, + const Tensor& self, + const c10::List>& indices); + +using scatter_add_expanded_index_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&); +using scatter_reduce_expanded_index_fn = void (*)( + const Tensor&, + const Tensor&, + const Tensor&, + const ReductionType& reduce, + bool); +using gather_expanded_index_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&); + +DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub) +DECLARE_DISPATCH( + scatter_reduce_expanded_index_fn, + scatter_reduce_expanded_index_stub) +DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h new file mode 100644 
index 0000000000000000000000000000000000000000..cef66b7409e9c180568db01c7bed5b4e658d09db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h @@ -0,0 +1,109 @@ +#pragma once +#include +#include +#include + +namespace at::native { +namespace { +#ifndef STRIP_ERROR_MESSAGES +inline std::string shapes_as_str(TensorList tensors) { + std::ostringstream os; + bool first = true; + for (auto& tensor : tensors) { + if (tensor.defined()) { + if (!first) { + os << ", "; + } + os << tensor.sizes(); + first = false; + } + } + return os.str(); +} +#endif +} // anonymous namespace + +inline std::tuple canDispatchToMaskedFill( + const Tensor& self, + const torch::List>& indices, + const Tensor& value) { + if (!(value.numel() == 1 && value.device().is_cpu())) { + return std::make_tuple(false, Tensor()); + } + int64_t num_ind = 0; + Tensor mask; + auto self_device = self.device(); + for (const std::optional& i : indices) { + if (!i.has_value() || !(*i).defined()) { + num_ind++; + } else { + const Tensor& index = *i; + if ((index.scalar_type() != kByte && index.scalar_type() != kBool) || + index.device() != self_device || mask.defined()) { + return std::make_tuple(false, Tensor()); + } else { + mask = index; + for (const auto j : c10::irange(index.dim())) { + int64_t srcIdx = num_ind + j; + TORCH_CHECK_INDEX( + index.size(j) == self.size(srcIdx), + "The shape of the mask ", + index.sizes(), + " at index ", + j, + " does not match the shape of the indexed tensor ", + self.sizes(), + " at index ", + srcIdx); + } + num_ind += mask.ndimension(); + } + } + } + for ([[maybe_unused]] const auto i : + c10::irange(num_ind, self.ndimension())) { + mask = mask.unsqueeze(-1); + } + return std::make_tuple(true, mask); +} + +inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { + checkIndexTensorTypes(orig, /*allow_int*/ true); + // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more + // LongTensors + auto indices = expandTensors(self, orig); + // next broadcast all index tensors together + try { + indices = expand_outplace(indices); + } catch (std::exception& e) { + TORCH_CHECK_INDEX( + false, + "shape mismatch: indexing tensors could not be broadcast together" + " with shapes ", + shapes_as_str(indices)); + } + // add missing null Tensors so that it matches self.dim() + while (indices.size() < (size_t)self.dim()) { + indices.emplace_back(); + } + // if the non-null indices are not all adjacent, transpose self and indices + // together so that they're adjacent at the front + if (!hasContiguousSubspace(indices)) { + std::tie(self, indices) = transposeToFront(self, indices); + } + // Ensure indices are on the same device as self + for (auto& indice : indices) { + if (indice.defined() && indice.device() != self.device()) { + indice = indice.to(self.device()); + } + } + for (auto& indice : indices) { + if (indice.defined() && indice.dtype() == at::kInt) { + indice = indice.to(at::kLong); + } + } + + return AdvancedIndex(self, indices); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorCompare.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorCompare.h new file mode 100644 index 0000000000000000000000000000000000000000..5efb5ac54f6b1453170c6c363cff471539d26b1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorCompare.h @@ -0,0 +1,56 @@ +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { +class Tensor; +struct TensorIterator; 
+struct TensorIteratorBase; +} // namespace at + +namespace at::native { + +using reduce_minmax_fn = + void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool); +using structured_reduce_minmax_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool); + +DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub) +DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub) + +using where_fn = void (*)(TensorIterator&); +DECLARE_DISPATCH(where_fn, where_kernel) + +using is_infinity_op_fn = void (*)(TensorIteratorBase&); +DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub) +DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub) + +using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool); +DECLARE_DISPATCH(mode_fn, mode_stub) + +using clamp_tensor_fn = void (*)(TensorIteratorBase&); +DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub) + +namespace detail { +enum class ClampLimits { Min, Max, MinMax }; +} + +DECLARE_DISPATCH( + void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&), + clamp_scalar_stub) +DECLARE_DISPATCH( + void (*)(TensorIteratorBase&, c10::Scalar), + clamp_min_scalar_stub) +DECLARE_DISPATCH( + void (*)(TensorIteratorBase&, c10::Scalar), + clamp_max_scalar_stub) + +using isin_default_fn = + void (*)(const Tensor&, const Tensor&, bool, const Tensor&); +DECLARE_DISPATCH(isin_default_fn, isin_default_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorConversions.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorConversions.h new file mode 100644 index 0000000000000000000000000000000000000000..dafdde502a6eba184b7c9b871855b7a33f4b400d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorConversions.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { +class Tensor; +namespace native { +bool to_will_alias( + const Tensor& self, + std::optional dtype, + std::optional layout, + std::optional device, + bool copy, + std::optional optional_memory_format); + +Tensor to_meta(const Tensor& tensor); +std::optional to_meta(const std::optional& tensor); +std::vector to_meta(at::ITensorListRef t_list); +Tensor dense_to_sparse_with_mask( + const Tensor& self, + const Tensor& mask, + std::optional layout, + OptionalIntArrayRef blocksize, + std::optional dense_dim_opt); + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorDimApply.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorDimApply.h new file mode 100644 index 0000000000000000000000000000000000000000..65f4f2108db6ee1f5de7dc7a8596bba740ad5058 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorDimApply.h @@ -0,0 +1,67 @@ +#pragma once +#include +#include + +namespace at::native { +// input tensors are non-zero dim and non-empty +template + +void tensor_dim_apply3( + const Tensor& self, + Tensor& values, + Tensor& indices, + int64_t dim, + Function func) { + int ndims = self.dim(); + int tensor_dim_apply_has_finished = 0; + std::vector counter(ndims, 0); + const T1* self_data = self.const_data_ptr(); + T1* values_data = values.data_ptr(); + T2* indices_data = indices.data_ptr(); + int64_t self_stride = self.stride(dim); + int64_t values_stride = values.stride(dim); + int64_t indices_stride = indices.stride(dim); + int self_dim_size = self.size(dim); + + while (!tensor_dim_apply_has_finished) { + func( + self_data, + values_data, + indices_data, + self_dim_size, + 
self_stride, + values_stride, + indices_stride); + if (ndims == 1) { + break; + } + for (const auto dim_i : c10::irange(ndims)) { + if (dim_i == dim) { + if (dim_i == (ndims - 1)) { + tensor_dim_apply_has_finished = 1; + break; + } + continue; + } + counter[dim_i]++; + self_data += self.stride(dim_i); + values_data += values.stride(dim_i); + indices_data += indices.stride(dim_i); + + if (counter[dim_i] == self.size(dim_i)) { + if (dim_i == ndims - 1) { + tensor_dim_apply_has_finished = 1; + break; + } else { + self_data -= counter[dim_i] * self.stride(dim_i); + values_data -= counter[dim_i] * values.stride(dim_i); + indices_data -= counter[dim_i] * indices.stride(dim_i); + counter[dim_i] = 0; + } + } else { + break; + } + } + } +} +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorFactories.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorFactories.h new file mode 100644 index 0000000000000000000000000000000000000000..094080c77412eee2578e4e5ed7dc37583f9a5782 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorFactories.h @@ -0,0 +1,169 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { +// Different combinations of row, col, and offset can lead to two cases: +// +// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col +// Example A: offset > 0 +// 1 1 0 0 0 +// 1 1 1 0 0 +// 1 1 1 1 0 +// Example B: offset <= 0 +// 0 0 0 +// 1 0 0 +// 1 1 0 +// In this case, we calculate the number of elements in the first row and +// last row of the tril respectively, and then compute the tril size. +// +// Case 2 - Trapezoid + Rectangle: row + offset > col +// Example: +// 1 1 0 +// 1 1 1 +// 1 1 1 +// In this case, we first calculate the size of top trapezoid, and then +// calculate the size of the bottom rectangle. +inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) { + // If either dimension is 0 then the there is no tril + if (row == 0 || col == 0) { + return 0; + } + // number of elements in the first row of the tril + auto m_first_row = offset > 0 ? 
std::min(col, 1 + offset) + : // upper bounded by col + row + offset > 0; // either 0 or 1 + // number of elements in the last row of the tril, bounded by [0, col] + auto m_last_row = std::max(0, std::min(col, row + offset)); + // number of rows, bounded by [0, row] + auto n_row_all = std::max(0, std::min(row, row + offset)); + auto n_row_trapezoid = (m_last_row - m_first_row + 1); + + // calculate # of elements in the top trapezoid + auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1; + + // calculate # of elements in the bottom rectangle if there is any + auto diff_row = n_row_all - n_row_trapezoid; + if (diff_row > 0) { + tril_size += diff_row * col; + } + + return tril_size; +} + +inline void check_args( + int64_t row, + int64_t col, + std::optional layout_opt) { + TORCH_CHECK(row >= 0, "row must be non-negative, got", row); + TORCH_CHECK(col >= 0, "col must be non-negative, got", col); + if (layout_opt.has_value()) { + TORCH_CHECK( + *layout_opt == at::kStrided, + "only support layout=torch.strided, got", + *layout_opt) + } +} + +using at::check_size_nonnegative; + +// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n)) +inline void check_supported_max_int_with_precision( + int64_t n, + const Tensor& tensor) { + // match defined() to behavior of checks below + TORCH_CHECK( + at::scalar_tensor(n > 0 ? n - 1 : n, tensor.options()).defined(), + "n is too large for result tensor type: '", + tensor.toString(), + "'"); + + // Ensure sufficient precision for floating point representation. + switch (tensor.scalar_type()) { + case at::ScalarType::Half: + TORCH_CHECK( + n <= (int64_t(1) << 11) + 1, + "n cannot be greater than 2049 for Half type."); + break; + case at::ScalarType::Float: + TORCH_CHECK( + n <= (int64_t(1) << 24) + 1, + "n cannot be greater than 2^24+1 for Float type."); + break; + case at::ScalarType::Double: // Unlikely to happen, but doesn't hurt to + // check + TORCH_CHECK( + n <= (int64_t(1) << 53) + 1, + "n cannot be greater than 2^53+1 for Double type."); + break; + default: + break; + } +} + +// Called by `empty*` functions when deterministic algorithms are enabled to +// fill the tensor with NaN if it is floating point or complex type, or fill +// with max value if it is integer type +inline Tensor& fill_empty_deterministic_(Tensor& tensor) { + if (tensor.is_floating_point() || tensor.is_complex()) { + AT_DISPATCH_V2( + tensor.scalar_type(), + "fill_empty_deterministic_", + AT_WRAP([&]() { + tensor.fill_(std::numeric_limits::quiet_NaN()); + }), + AT_EXPAND(AT_FLOATING_TYPES), + AT_EXPAND(AT_COMPLEX_TYPES), + AT_EXPAND(AT_FLOAT8_TYPES), + kBFloat16, + kHalf, + kComplexHalf); + } else { + AT_DISPATCH_V2( + tensor.scalar_type(), + "fill_empty_deterministic_", + AT_WRAP([&]() { tensor.fill_(std::numeric_limits::max()); }), + kBool, + AT_EXPAND(AT_INTEGRAL_TYPES_V2)); + } + return tensor; +} + +// The ZeroTensor allocator ignores whatever allocation is requested and always +// gives you nullptr +struct ZeroTensorAllocator final : public at::Allocator { + ZeroTensorAllocator(at::Device device) : device_(device) {} + ~ZeroTensorAllocator() override = default; + static void deleter(void* const pointer) { + TORCH_INTERNAL_ASSERT(!pointer); + } + DataPtr allocate(const size_t /*nbytes*/) override { + return {nullptr, nullptr, &deleter, device_}; + } + DeleterFnPtr raw_deleter() const override { + return deleter; + } + void copy_data( + void* dest [[maybe_unused]], + const void* src [[maybe_unused]], + std::size_t count [[maybe_unused]]) const 
final {} + at::Device device_; +}; + +using binary_fn = void (*)(TensorIterator&); + +DECLARE_DISPATCH(binary_fn, complex_stub) +DECLARE_DISPATCH(binary_fn, polar_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorIterator.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorIterator.h new file mode 100644 index 0000000000000000000000000000000000000000..4fb52e967ad7da6e58fca440b588f20767c0bf15 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorIterator.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h new file mode 100644 index 0000000000000000000000000000000000000000..93331ee969b48280bb55cdc3666b0da0d2269bb6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file includes utilities for dynamic_casting done by TensorIterator, see +// CUDALoops.cuh and Loops.h. + +// dynamic_casting handles when the types expected by the iterator do not match +// the types of the arguments to the function that is being called. On CUDA, the +// cast is currently pushed down into the kernel (for performance reasons). On +// CPU, there is currently an internal assert that a dynamic_cast is not needed. + +namespace at::native { + +// `needs_dynamic_casting` compares the types expected by iterator +// (i.e. dtypes of the operands) with the actual type of the arguments +// (and returns) of func_t +template ::arity> +struct needs_dynamic_casting { + static bool check(TensorIteratorBase& iter) { + using traits = function_traits; + using cpp_type = typename traits::template arg::type; + using cpp_map = c10::CppTypeToScalarType; + + if (iter.input_dtype(nargs - 1) != cpp_map::value) { + return true; + } + return needs_dynamic_casting::check(iter); + } +}; + +template +struct needs_dynamic_casting { + static bool check(TensorIteratorBase& iter) { + using traits = function_traits; + using cpp_type = typename traits::result_type; + + // we could assert output numbers are correct here, but checks + // (including arity) are currently pushed outside of this struct. + if constexpr (std::is_void_v) { + return false; + } else { + return iter.dtype(0) != c10::CppTypeToScalarType::value; + } + } +}; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorProperties.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..8654b3dae577b192c75c9cb8f74ea417bcd3b961 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorProperties.h @@ -0,0 +1,12 @@ +#pragma once + +// See NOTE: [Tensor vs. 
TensorBase] +namespace at { +class TensorBase; +} + +namespace at::native { + +TORCH_API bool cudnn_is_acceptable(const TensorBase& self); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorShape.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorShape.h new file mode 100644 index 0000000000000000000000000000000000000000..7f9f9abfaf74195e5a3f2e9deeba1948ac5b76af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorShape.h @@ -0,0 +1,145 @@ +#pragma once +#include +#include +#include + +namespace at::native { + +TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self); + +inline bool cat_should_skip_tensor(const Tensor& t) { + return t.sym_numel() == 0 && t.dim() == 1; +} + +// Check to see if the shape of tensors is compatible +// for being concatenated along a given dimension. +inline void check_cat_shape_except_dim( + const Tensor& first, + const Tensor& second, + int64_t dimension, + int64_t index) { + int64_t first_dims = first.dim(); + int64_t second_dims = second.dim(); + TORCH_CHECK( + first_dims == second_dims, + "Tensors must have same number of dimensions: got ", + first_dims, + " and ", + second_dims); + for (const auto dim : c10::irange(first_dims)) { + if (dim == dimension) { + continue; + } + int64_t first_dim_size = first.sizes()[dim]; + int64_t second_dim_size = second.sizes()[dim]; + TORCH_CHECK( + first_dim_size == second_dim_size, + "Sizes of tensors must match except in dimension ", + dimension, + ". Expected size ", + static_cast(first_dim_size), + " but got size ", + static_cast(second_dim_size), + " for tensor number ", + index, + " in the list."); + } +} + +inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) { + [[maybe_unused]] int64_t i = 0; + for (const Tensor& t : tensors) { + TORCH_CHECK( + t.dim() > 0, + "zero-dimensional tensor (at position ", + i, + ") cannot be concatenated"); + i++; + } +} + +inline int64_t get_num_splits( + const Tensor& self, + int64_t split_size, + int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + TORCH_CHECK( + split_size >= 0, + "split expects split_size be non-negative, but got split_size=", + split_size); + int64_t dim_size = self.size(dim); + TORCH_CHECK( + split_size > 0 || dim_size == 0, + "split_size can only be 0 if dimension size is 0, " + "but got dimension size of ", + dim_size); + // if split_size is 0 and dimension size is 0, there is 1 split. + int64_t num_splits = 1; + if (split_size != 0) { + // ensuring num_splits is at least 1 makes consistent the case where + // split_size > dim_size (returns a single split). We might want to error + // here, but keep it for BC. 
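// Illustrative example (hypothetical numbers, not part of the upstream header):
// with the ceiling-division formula on the next line, dim_size = 10 and
// split_size = 3 give max((10 + 3 - 1) / 3, 1) = 4 splits of sizes 3, 3, 3, 1,
// while a split_size larger than dim_size (e.g. 16) still yields exactly one
// split because of the std::max with 1.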
+ num_splits = std::max((dim_size + split_size - 1) / split_size, 1); + } + return num_splits; +} + +inline bool have_same_ndims(TensorList tensors) { + auto ndim = tensors[0].dim(); + for (const auto tensor_idx : c10::irange(tensors.size())) { + if (tensors[tensor_idx].dim() != ndim) { + return false; + } + } + return true; +} + +inline void leading_dimension_matches(TensorList tensors, int64_t dim) { + auto tensor_zero_size = tensors[0].sizes(); + std::vector leading_dim_sizes( + tensor_zero_size.begin(), tensor_zero_size.begin() + dim); + for (const auto i : c10::irange(tensors.size())) { + at::Tensor tensor = tensors[i]; + for (const auto j : c10::irange(dim)) { + TORCH_CHECK( + tensor.size(j) == leading_dim_sizes[j], + "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"); + } + } +} + +inline int64_t preprocess_chunk_cat_inputs( + TensorList tensors, + int64_t dim, + int64_t num_chunks) { + TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks"); + TORCH_CHECK( + !tensors.empty(), "_chunk_cat expects a non-empty input tensor list"); + auto expected_dtype = tensors[0].dtype(); + auto expected_device = tensors[0].device(); + for (const auto i : c10::irange(tensors.size())) { + TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor"); + TORCH_CHECK( + tensors[i].dtype() == expected_dtype, + "_chunk_cat expects all input tensors with the same dtype"); + TORCH_CHECK( + tensors[i].device() == expected_device, + "_chunk_cat expects all inputs tensors on the same device"); + } + if (have_same_ndims(tensors)) { + dim = maybe_wrap_dim(dim, tensors[0].dim()); + } else { + TORCH_CHECK( + dim >= 0, + "_chunk_cat expects non-negative dim when input tensors have different ndims") + for (const auto i : c10::irange(tensors.size())) { + TORCH_CHECK( + dim < tensors[i].ndimension(), + "_chunk_cat expects dim < ndim for all input tensors"); + } + } + leading_dimension_matches(tensors, dim); + return dim; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TensorTransformations.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorTransformations.h new file mode 100644 index 0000000000000000000000000000000000000000..1b662a2edbaa073e77a82643fe4be37a39b7674d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TensorTransformations.h @@ -0,0 +1,35 @@ +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include + +namespace at::native { + +static inline Tensor roll_common( + const Tensor& self, + IntArrayRef shifts, + IntArrayRef dims) { + TORCH_CHECK(!shifts.empty(), "`shifts` required"); + if (dims.empty() && shifts.size() == 1) { + auto flattened = self.contiguous().view(self.numel()); + return roll(flattened, shifts[0], 0).view(self.sizes()); + } + TORCH_CHECK( + shifts.size() == dims.size(), + "shifts and dimensions must align. 
shifts: ", + shifts.size(), + ", dims:", + dims.size()); + AT_ASSERT(dims.size() > 1); + auto tail_shifts = shifts.slice(1); + auto tail_dims = dims.slice(1); + auto first_dim_rolled = roll(self, shifts[0], dims[0]); + return at::roll(first_dim_rolled, tail_shifts, tail_dims); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TopKImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TopKImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..a8ffaf61295398c9e7a28bdcbc77d4c81e9b3846 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TopKImpl.h @@ -0,0 +1,98 @@ +#pragma once +#include +#include + +namespace at::native { + +#ifdef CPU_CAPABILITY +inline namespace CPU_CAPABILITY { +#else +inline namespace DEFAULT { +#endif + +// Core topk loop, shared between CPU and QuantizedCPU +template +void topk_impl_loop( + const int64_t mode_values_stride, + const int64_t mode_indices_stride, + const int64_t tmp_values_stride, + const int64_t k, + const int64_t dim_size, + const bool largest, + const bool sorted, + char** data, const int64_t* strides, const int64_t n) { + + // If k is zero, then output values and indices are empty tensors + // So iterating over other dims is pointless + if (k == 0) { + return; + } + using elem_t = std::pair; + std::vector queue(dim_size); + for (const auto i : c10::irange(n)) { + TensorAccessor mode_values( + reinterpret_cast(data[0] + i * strides[0]), + &k, &mode_values_stride); + TensorAccessor mode_indices( + reinterpret_cast(data[1] + i * strides[1]), + &k, &mode_indices_stride); + TensorAccessor tmp_values( + reinterpret_cast(data[2] + i * strides[2]), + &dim_size, &tmp_values_stride); + + auto n_2 = dim_size; + auto use_partial_sort = k * 64 <= n_2; + + for (const auto j : c10::irange(n_2)) { + queue[j].first = tmp_values[j]; + queue[j].second = j; + } + + // we want nan to be sorted as top for numpy compatibility + if (use_partial_sort) { + if (largest) { + std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + } else { + std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + } + } else { + if (largest) { + std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + if (sorted) { + std::sort(queue.begin(), queue.begin() + k - 1, + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + } + } else { + std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + if (sorted) { + std::sort(queue.begin(), queue.begin() + k -1, + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + } + } + } + + for (const auto j : c10::irange(k)) { + mode_values[j] = queue[j].first; + mode_indices[j] = queue[j].second; + } + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TransposeType.h 
b/phivenv/Lib/site-packages/torch/include/ATen/native/TransposeType.h new file mode 100644 index 0000000000000000000000000000000000000000..2ebdce31873a4ff7e6269551d374952a35f49fdc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TransposeType.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { + +// Used as an interface between the different BLAS-like libraries +enum class TransposeType { + NoTranspose, + Transpose, + ConjTranspose, +}; + +// Transforms TransposeType into the BLAS / LAPACK format +static inline char to_blas(TransposeType trans) { + switch (trans) { + case TransposeType::Transpose: return 'T'; + case TransposeType::NoTranspose: return 'N'; + case TransposeType::ConjTranspose: return 'C'; + } + TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TriangularOpsUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TriangularOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..27fe2e18cb685b5fe32214b0fe10466d2b5d0189 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TriangularOpsUtils.h @@ -0,0 +1,57 @@ +#include +#include + +namespace at::native { + +/* + * Given batches of matrices with arbitrary batch dim, + * computes the number of batches for Triu and Tril. This ignores stride 0 dimension + */ +static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) { + int64_t result = 1; + for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) { + if (batched_matrices.stride(i) != 0) { + result *= batched_matrices.size(i); + } + } + return result; +} + +/* Checks a necessary property for the triu and tril implementations, hence the name. + * Here batch contiguity is checked for tensors with greater than 4 dimensions. 
+ * Contiguous tensors and tensors with less than 3 dimensions pass this check + */ +static inline std::tuple checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) { + // Complete contiguity is the most desired property, which is why + // we return true if the tensor is contiguous + if (tensor.is_contiguous()) { + auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes()); + if (tensor.strides() == default_strides_for_size) { + return std::make_tuple(true, tensor); + } else { + return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size)); + } + } + + int64_t dims = tensor.dim(); + + // Tensors with dimension less than 4 are handled by default + if (allow_zero_stride && dims <= 3) { + return std::make_tuple(true, tensor); + } + + int64_t expected_stride = tensor.size(-1) * tensor.size(-2); + for (int64_t i = dims - 3; i >= 0; i--) { + // Skip trivial dimension; + if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) { + continue; + } + if (expected_stride != tensor.stride(i)) { + return std::make_tuple(false, tensor.contiguous()); + } + expected_stride *= tensor.size(i); + } + return std::make_tuple(true, tensor); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/TypeProperties.h b/phivenv/Lib/site-packages/torch/include/ATen/native/TypeProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..07f0028655e58f6c1305251782ad6a5e51ad7a74 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/TypeProperties.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace at::native { + +struct ResultTypeState { + c10::ScalarType dimResult = ScalarType::Undefined; + c10::ScalarType wrappedResult = ScalarType::Undefined; + c10::ScalarType zeroResult = ScalarType::Undefined; +}; + +TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state); +TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state); +TORCH_API ScalarType result_type(const ResultTypeState& state); + +TORCH_API ScalarType result_type(ITensorListRef tensors); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/UnaryOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/UnaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..ab4d03c98bd44ed6f9cc76fc079b838ea522a7a6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/UnaryOps.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include + +namespace at { +class Tensor; +class TensorBase; +struct TensorIteratorBase; +} + +namespace at::native { + +using unary_fn = void(*)(TensorIteratorBase&); +using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a); + +inline namespace CPU_CAPABILITY { +void conj_kernel(TensorIteratorBase &iter); +void neg_kernel(TensorIteratorBase &iter); +void reciprocal_kernel(TensorIteratorBase &iter); +void rsqrt_kernel(TensorIteratorBase& iter); +void sqrt_kernel(TensorIteratorBase& iter); +} // namespace CPU_CAPABILITY + +DECLARE_DISPATCH(unary_fn, abs_stub) +DECLARE_DISPATCH(unary_fn, angle_stub) +DECLARE_DISPATCH(unary_fn, conj_physical_stub) +DECLARE_DISPATCH(unary_fn, acos_stub) +DECLARE_DISPATCH(unary_fn, acosh_stub) +DECLARE_DISPATCH(unary_fn, asinh_stub) +DECLARE_DISPATCH(unary_fn, atanh_stub) +DECLARE_DISPATCH(unary_fn, asin_stub) +DECLARE_DISPATCH(unary_fn, 
atan_stub) +DECLARE_DISPATCH(unary_fn, bitwise_not_stub) +DECLARE_DISPATCH(unary_fn, logical_not_stub) +DECLARE_DISPATCH(unary_fn, ceil_stub) +DECLARE_DISPATCH(unary_fn, cos_stub) +DECLARE_DISPATCH(unary_fn, cosh_stub) +DECLARE_DISPATCH(unary_fn, digamma_stub) +DECLARE_DISPATCH(unary_fn, special_entr_stub) +DECLARE_DISPATCH(unary_fn, special_erfcx_stub) +DECLARE_DISPATCH(unary_fn, erf_stub) +DECLARE_DISPATCH(unary_fn, erfc_stub) +DECLARE_DISPATCH(unary_fn, erfinv_stub) +DECLARE_DISPATCH(unary_fn, exp_stub) +DECLARE_DISPATCH(unary_fn, exp2_stub) +DECLARE_DISPATCH(unary_fn, expm1_stub) +DECLARE_DISPATCH(unary_fn, floor_stub) +DECLARE_DISPATCH(unary_fn, frac_stub) +DECLARE_DISPATCH(unary_fn, frexp_stub) +DECLARE_DISPATCH(unary_fn, i0_stub) +DECLARE_DISPATCH(unary_fn, special_i0e_stub) +DECLARE_DISPATCH(unary_fn, special_i1_stub) +DECLARE_DISPATCH(unary_fn, special_i1e_stub) +DECLARE_DISPATCH(unary_fn, log_stub) +DECLARE_DISPATCH(unary_fn, log10_stub) +DECLARE_DISPATCH(unary_fn, log1p_stub) +DECLARE_DISPATCH(unary_fn, log2_stub) +DECLARE_DISPATCH(unary_fn, special_ndtri_stub) +DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub) +DECLARE_DISPATCH(unary_fn, neg_stub) + +DECLARE_DISPATCH(unary_fn, reciprocal_stub) +DECLARE_DISPATCH(unary_fn, round_stub) +DECLARE_DISPATCH(unary_fn, rsqrt_stub) +DECLARE_DISPATCH(unary_fn, sigmoid_stub) +DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub) +DECLARE_DISPATCH(unary_fn, sign_stub) +DECLARE_DISPATCH(unary_fn, signbit_stub) +DECLARE_DISPATCH(unary_fn, sgn_stub) +DECLARE_DISPATCH(unary_fn, sin_stub) +DECLARE_DISPATCH(unary_fn, sinc_stub) +DECLARE_DISPATCH(unary_fn, sinh_stub) +DECLARE_DISPATCH(unary_fn, sqrt_stub) +DECLARE_DISPATCH(unary_fn, tan_stub) +DECLARE_DISPATCH(unary_fn, tanh_stub) +DECLARE_DISPATCH(unary_fn, trigamma_stub) +DECLARE_DISPATCH(unary_fn, trunc_stub) +DECLARE_DISPATCH(unary_fn, lgamma_stub) +DECLARE_DISPATCH(unary_fn, special_airy_ai_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub) +DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub) +DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub) +DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub) + +// NB: these are actually defined in Distribution +DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional), bernoulli_tensor_stub) +DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional), bernoulli_scalar_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional), cauchy_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional), exponential_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional), geometric_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional), log_normal_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional), uniform_stub) +DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional), normal_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional), random_from_to_stub) 
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional), random_full_64_bits_range_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional), random_stub) + +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub) +DECLARE_DISPATCH( + void (*)(Tensor&, const Tensor&, int64_t, std::optional), + multinomial_with_replacement_stub) +DECLARE_DISPATCH( + void (*)( + TensorIteratorBase&, + std::optional, + std::optional, + std::optional), + nan_to_num_stub) +DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub) + +// Missing unary functions +// digamma +// lgamma +// erfinv +// clone +// contiguous +// zero +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Unfold2d.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Unfold2d.h new file mode 100644 index 0000000000000000000000000000000000000000..249a5747912409cd23ab8514db26d1b431958e4c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Unfold2d.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +using unfold2d_copy_fn = void (*)( + ScalarType dtype, + void *finput, + const void *input, + int64_t kH, + int64_t kW, + int64_t dH, + int64_t dW, + int64_t padH, + int64_t padW, + int64_t n_input_plane, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + bool is_channels_last +); + +using unfold2d_acc_fn = void (*)( + ScalarType dtype, + void *finput, + void *input, + int64_t kH, + int64_t kW, + int64_t dH, + int64_t dW, + int64_t padH, + int64_t padW, + int64_t n_input_plane, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + bool is_channels_last +); + +DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub) +DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/Unfold3d.h b/phivenv/Lib/site-packages/torch/include/ATen/native/Unfold3d.h new file mode 100644 index 0000000000000000000000000000000000000000..eae526b7ec33a2ec1b34aeee808f78fc47931c82 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/Unfold3d.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +namespace at::native { + +void Unfold3dCopyCPU( + ScalarType dtype, + const void *src, + int64_t C, + int64_t X_D, + int64_t X_H, + int64_t X_W, + int64_t Y_D, + int64_t Y_H, + int64_t Y_W, + int64_t kernel_d, + int64_t kernel_h, + int64_t kernel_w, + int64_t stride_d, + int64_t stride_h, + int64_t stride_w, + int64_t pad_d, + int64_t pad_h, + int64_t pad_w, + void* dst); + +void Unfold3dAccCPU( + ScalarType dtype, + const void *src, + int64_t C, + int64_t X_D, + int64_t X_H, + int64_t X_W, + int64_t Y_D, + int64_t Y_H, + int64_t Y_W, + int64_t kernel_d, + int64_t kernel_h, + int64_t kernel_w, + int64_t stride_d, + int64_t stride_h, + int64_t stride_w, + int64_t pad_d, + int64_t pad_h, + int64_t pad_w, + void *dst); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/UnfoldBackward.h b/phivenv/Lib/site-packages/torch/include/ATen/native/UnfoldBackward.h new file mode 100644 index 0000000000000000000000000000000000000000..f96f45d8b3ea1a3d145168ced09ea836b9b2c9f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/UnfoldBackward.h @@ -0,0 +1,110 @@ +#pragma once + +#include 
+#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { + +using unfold_backward_fn = void (*)( + Tensor& grad_in, + const Tensor& grad, + int64_t dim, + int64_t size, + int64_t step +); + +DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub) + +namespace { + +// Note on naming: it is unconventional. +// grad_in does not mean that it is a gradient wrt to input, +// grad_in/grad_out is just an input/output of unfold_backward kernel. + +[[maybe_unused]] static TensorIterator _make_unfold_backward_iter_over_grad_out( + Tensor& grad_out, + const Tensor& grad_in, + int64_t dim, + int64_t size, + int64_t step) { + dim = maybe_wrap_dim(dim, grad_out.dim()); + // last dim stores the folds + + auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim); + auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim); + // dictates the number of elements to iterate over + // in dimension `dim` + auto iter_dim_size = std::min( + grad_out_dim_size, + (grad_in_dim_size - 1) * step + size + ); + + /* prepare grad_out for TensorIterator { */ + auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec()); + auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec()); + grad_out_sizes[dim] = iter_dim_size; + auto grad_out_restrided = grad_out.as_strided( + grad_out_sizes, grad_out_strides + ); + /* } */ + + /* prepare grad_in for TensorIterator { */ + auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec()); + auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec()); + + // set strides for dim to 0 + // and size to 1 because + // this dimension is indexed inside the kernel + grad_in_strides[dim] = 0; + grad_in_sizes[dim] = 1; + + grad_in_strides.pop_back(); + grad_in_sizes.pop_back(); + + auto grad_in_restrided = grad_in.squeeze(-1).as_strided( + grad_in_sizes, grad_in_strides + ); + /* } */ + + // During the TensorIterator iteration we have to know + // i_dim in grad_out[i_1,...,i_dim,...i_n], + // idx_dim stores this information + /* prepare idx_dim for TensorIterator { */ + auto idx_dim = at::arange( + 0, iter_dim_size, grad_in.options().dtype(at::kLong) + ); + + auto grad_out_dim = ensure_nonempty_dim(grad_out.dim()); + + auto idx_dim_strides = std::vector(grad_out_dim, 0); + auto idx_dim_sizes = std::vector(grad_out_dim, 1); + + idx_dim_strides[dim] = 1; + idx_dim_sizes[dim] = iter_dim_size; + + // idx_dim size will broadcast over determined by grad_out sizes in TensorIterator + auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides); + /* } */ + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_owned_output(grad_out_restrided) + .add_owned_const_input(grad_in_restrided) + .add_owned_const_input(idx_dim_restrided) + .build(); + + return iter; +} +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/UpSample.h b/phivenv/Lib/site-packages/torch/include/ATen/native/UpSample.h new file mode 100644 index 0000000000000000000000000000000000000000..452d7de79aae4c6e8f92872d03b4a07f8255666a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/UpSample.h @@ -0,0 +1,510 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * Note [compute_scales_value] + * Note [area_pixel_compute_scale] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Interpolate with scale_factor can have different 
behaviors + * depending on the value of recompute_scale_factor: + * + * - With recompute_scale_factor = True (current default behavior): + * the scale_factor, when provided by the user, are used to calculate + * the output size. The input size and the computed output_size + * are then used to infer new values for the scales which are + * used in the interpolation. Because floating-point math is not exact, + * this may be a different value from the user-supplied scales. + * + * - With recompute_scale_factor = False (which will be the default + * behavior starting 1.5.0): + * the behavior follows opencv logic, and the scales provided by + * the user are the ones used in the interpolation calculations. + * + * If the scales are not provided or if they are provided but + * recompute_scale_factor is set to True (default behavior), the scales + * are computed from the input and the output size; + * + * + * When the scales are inferred from the input and output sizes, + * we view each pixel as an area, idx + 0.5 as its center index. + * Here is an example formula in 1D case. + * if align_corners: center of two corner pixel areas are preserved, + * (0.5, 0.5) -> (0.5, 0.5), + * (input_size - 0.5, 0.5) -> (output_size - 0.5) + * scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5) + * src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5) + * if not align_corners: the whole range is scaled accordingly + * scale = input_size / output_size + * src_idx + 0.5 = scale * (dst_index + 0.5) + */ + +namespace at::native { + +namespace upsample { + +TORCH_API c10::SmallVector compute_output_size( + c10::IntArrayRef input_size, // Full input tensor size. + at::OptionalIntArrayRef output_size, + std::optional> scale_factors); + +inline std::optional get_scale_value(std::optional> scales, int idx) { + if (!scales) { + return std::nullopt; + } + return scales->at(idx); +} + +} // namespace upsample + +using scale_t = std::optional; +using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w); +using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w); +using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w); +using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w); +using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w); +using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w); +using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w); +using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w); +using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); 
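// Illustrative example for Note [area_pixel_compute_scale] above (hypothetical
// sizes, not part of the upstream header): scaling a length-4 input to a
// length-8 output gives
//   align_corners = true:  scale = (4 - 1) / (8 - 1) = 3/7
//   align_corners = false: scale = 4 / 8 = 0.5, and
//                          src_idx = 0.5 * (dst_idx + 0.5) - 0.5,
//                          e.g. dst_idx = 3 -> src_idx = 1.25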
+DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel) +DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel) +DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel) +DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel) +DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel) +DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel) +DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel) +DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel) +DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel) +DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel) +DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel) +DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel) +DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel) +DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel) +DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel) +DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel) +DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel) + +[[maybe_unused]] inline std::array upsample_1d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + int64_t output_width = output_size[0]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_width = input_size[2]; + + TORCH_CHECK( + input_width > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", + input_width, + ") and output (W: ", + output_width, + ")"); + + return {nbatch, channels, output_width}; +} + +[[maybe_unused]] inline std::array upsample_2d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 4, + "It is expected input_size equals to 4, but got size ", + input_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_height = input_size[2]; + int64_t input_width = input_size[3]; + + TORCH_CHECK( + input_height > 0 && input_width > 0 && output_height > 0 && + output_width > 0, + "Input and output sizes should be greater than 0," + " but got input (H: ", + input_height, + ", W: ", + input_width, + ") output (H: ", + output_height, + ", W: ", + output_width, + ")"); + + return {nbatch, channels, output_height, output_width}; +} + +[[maybe_unused]] inline std::array upsample_3d_common_check( + 
IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 3, + "It is expected output_size equals to 3, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 5, + "It is expected input_size equals to 5, but got size ", + input_size.size()); + + int64_t output_depth = output_size[0]; + int64_t output_height = output_size[1]; + int64_t output_width = output_size[2]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_depth = input_size[2]; + int64_t input_height = input_size[3]; + int64_t input_width = input_size[4]; + + TORCH_CHECK( + input_depth > 0 && input_height > 0 && input_width > 0 && + output_depth > 0 && output_height > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (D: ", + input_depth, + ", H: ", + input_height, + ", W: ", + input_width, + ") output (D: ", + output_depth, + ", H: ", + output_height, + ", W: ", + output_width, + ")"); + + + return {nbatch, channels, output_depth, output_height, output_width}; +} + +inline void upsample_2d_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t nbatch, + int64_t nchannels, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width) { + TORCH_CHECK( + input_height > 0 && input_width > 0 && output_height > 0 && + output_width > 0, + "Input and output sizes should be greater than 0," + " but got input (H: ", + input_height, + ", W: ", + input_width, + ") output (H: ", + output_height, + ", W: ", + output_width, + ")"); + + if (input.defined()) { + // Allow for empty batch size but not other dimensions + TORCH_CHECK( + (input.numel() != 0 || + (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0) + ) && + input.dim() == 4, + "Non-empty 4D data tensor expected but got a tensor with sizes ", + input.sizes()); + } else if (grad_output.defined()) { + check_dim_size(grad_output, 4, 0, nbatch); + check_dim_size(grad_output, 4, 1, nchannels); + check_dim_size(grad_output, 4, 2, output_height); + check_dim_size(grad_output, 4, 3, output_width); + } +} + +template +inline scalar_t compute_scales_value( + const std::optional scale, + int64_t input_size, + int64_t output_size) { + // see Note [compute_scales_value] + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) + ? static_cast(1.0 / scale.value()) + : (static_cast(input_size) / output_size); +} + +template +inline scalar_t area_pixel_compute_scale( + int64_t input_size, + int64_t output_size, + bool align_corners, + const std::optional scale) { + // see Note [area_pixel_compute_scale] + if(align_corners) { + if(output_size > 1) { + return static_cast(input_size - 1) / (output_size - 1); + } else { + return static_cast(0); + } + } else { + return compute_scales_value(scale, input_size, output_size); + } +} + +template +inline scalar_t area_pixel_compute_source_index( + scalar_t scale, + int64_t dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + scalar_t src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + // [Note] Follow Opencv resize logic: + // We allow negative src_idx here and later will use + // dx = src_idx - floorf(src_idx) + // to compute the "distance"(which affects weights). + // For linear modes, weight distribution doesn't matter + // for negative indices as they use 2 pixels to interpolate. 
+ // For example, [-1, 0], they both use pixel 0 value so it + // doesn't affect if we bound the src_idx to 0 or not. + // TODO: Our current linear mode impls use unbound indices + // where we should and then remove this cubic flag. + // This matters in cubic mode, as we might need [-1, 0, 1, 2] + // to interpolate and the weights can be affected. + return (!cubic && src_idx < static_cast(0)) ? scalar_t(0) + : src_idx; + } +} + +inline int64_t nearest_neighbor_compute_source_index( + const float scale, + int64_t dst_index, + int64_t input_size) { + // Index computation matching OpenCV INTER_NEAREST + // which is buggy and kept for BC + const int64_t src_index = + std::min(static_cast(floorf(dst_index * scale)), input_size - 1); + return src_index; +} + +inline int64_t nearest_neighbor_exact_compute_source_index( + const float scale, + int64_t dst_index, + int64_t input_size) { + // index_f32 = (output_index + 0.5) * scale - 0.5 + // input_index = round(index_f32) + // Same as Pillow and Scikit-Image/Scipy ndi.zoom + const int64_t src_index = + std::min(static_cast(floorf((dst_index + 0.5) * scale)), input_size - 1); + return src_index; +} + +inline int64_t nearest_idx( + int64_t output_index, + int64_t input_size, + int64_t output_size, + std::optional scales) { + // This method specifically treats cases: output_size == input_size or + // output_size == 2 * input_size, that we would like to get rid of + // We keep this method for BC and consider as deprecated. + // See nearest_exact_idx as replacement + if (output_size == input_size) { + // scale_factor = 1, simply copy + return output_index; + } else if (output_size == 2 * input_size) { + // scale_factor = 2, shift input index + return output_index >> 1; + } else { + float scale = compute_scales_value(scales, input_size, output_size); + return nearest_neighbor_compute_source_index(scale, output_index, input_size); + } +} + +inline int64_t nearest_exact_idx( + int64_t output_index, + int64_t input_size, + int64_t output_size, + std::optional scales) { + float scale = compute_scales_value(scales, input_size, output_size); + return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size); +} + +// Define a typedef to dispatch to nearest_idx or nearest_exact_idx +typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional); + +template +scalar_t upsample_get_value_bounded( + scalar_t* data, + int64_t width, + int64_t height, + int64_t x, + int64_t y) { + int64_t access_x = std::max(std::min(x, width - 1), static_cast(0)); + int64_t access_y = std::max(std::min(y, height - 1), static_cast(0)); + return data[access_y * width + access_x]; +} + +template +void upsample_increment_value_bounded( + scalar_t* data, + int64_t width, + int64_t height, + int64_t x, + int64_t y, + scalar_t value) { + int64_t access_x = std::max(std::min(x, width - 1), static_cast(0)); + int64_t access_y = std::max(std::min(y, height - 1), static_cast(0)); + data[access_y * width + access_x] += value; +} + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +scalar_t cubic_convolution1(scalar_t x, scalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +scalar_t cubic_convolution2(scalar_t x, scalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +void get_cubic_upsample_coefficients( + scalar_t coeffs[4], + scalar_t t) { + scalar_t A = -0.75; + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = 
cubic_convolution1(x1, A); + + // opposite coefficients + scalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +inline scalar_t cubic_interp1d( + scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + scalar_t t) { + scalar_t coeffs[4]; + get_cubic_upsample_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +// when `real_input_index` becomes larger than the range the floating point +// type can accurately represent, the type casting to `int64_t` might exceed +// `input_size`, causing overflow. So we guard it with `std::min` below. +template +inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) { + input_index = std::min(static_cast(floorf(real_input_index)), input_size - 1); + lambda = std::min( + std::max(real_input_index - input_index, static_cast(0)), + static_cast(1) + ); +} + +template +inline void compute_source_index_and_lambda( + int64_t& input_index0, + int64_t& input_index1, + scalar_t& lambda0, + scalar_t& lambda1, + opmath_t ratio, + int64_t output_index, + int64_t input_size, + int64_t output_size, + bool align_corners) { + if (output_size == input_size) { + // scale_factor = 1, simply copy + input_index0 = output_index; + input_index1 = output_index; + lambda0 = static_cast(1); + lambda1 = static_cast(0); + } else { + const auto real_input_index = + area_pixel_compute_source_index( + ratio, output_index, align_corners, /*cubic=*/false); + guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1); + int64_t offset = (input_index0 < input_size - 1) ? 1 : 0; + input_index1 = input_index0 + offset; + lambda0 = static_cast(1.) - lambda1; + } +} + +// It will not be used by data types other than BFloat16 and Half. 
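+
+// A minimal sketch of how the helpers above are typically combined: a 1D
+// linear resize of a single contiguous row. The function name and the
+// `in`/`out` buffers are hypothetical and serve only to illustrate
+// area_pixel_compute_scale / compute_source_index_and_lambda; the real
+// upsample kernels live in the corresponding .cpp implementation files.
+template <typename scalar_t>
+inline void linear_resize_1d_row_sketch(
+    const scalar_t* in,
+    scalar_t* out,
+    int64_t input_size,
+    int64_t output_size,
+    bool align_corners,
+    std::optional<double> scale) {
+  // Maps output indices back into (fractional) input coordinates.
+  const scalar_t ratio = area_pixel_compute_scale<scalar_t>(
+      input_size, output_size, align_corners, scale);
+  for (int64_t d = 0; d < output_size; d++) {
+    int64_t i0 = 0, i1 = 0;
+    scalar_t w0 = 0, w1 = 0;
+    // Two neighboring input indices and their interpolation weights.
+    compute_source_index_and_lambda(
+        i0, i1, w0, w1, ratio, d, input_size, output_size, align_corners);
+    out[d] = w0 * in[i0] + w1 * in[i1];
+  }
+}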
+template || !std::is_same_v, int> = 0> +void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) { + TORCH_CHECK((is_reduced_floating_point_v), + "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.") + TORCH_CHECK((std::is_same_v), + "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.") + return; +} + +template && std::is_same_v, int> = 0> +void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) { + using bVec = Vectorized; + using fVec = Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec gin_bvec = bVec::loadu(gin + d); + auto [gin_fvec0, gin_fvec1] = convert_to_float(gin_bvec); + gin_fvec0 += fVec::loadu(buffer_ptr + d); + gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size()); + fVec(0).store(buffer_ptr + d); + fVec(0).store(buffer_ptr + d + fVec::size()); + convert_from_float(gin_fvec0, gin_fvec1).store(gin + d); + } + for (; d < size; d++) { + gin[d] += buffer_ptr[d]; + buffer_ptr[d] = 0; + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..cfff8f8f038cc29bb7b2e882746d4a8c632a0aee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include + +#ifdef USE_FBGEMM +#include +#include +#include + + +namespace ao::sparse { + +struct TORCH_API PackedLinearWeight + : public LinearPackedParamsBase { + PackedLinearWeight(std::unique_ptr> w, + std::optional bias, + std::vector col_offsets, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme, + const int64_t out_features_block_size /* block sparsity size across output_features */, + const int64_t in_features_block_size /* block sparsity size across input_features */) + : LinearPackedParamsBase( + out_features_block_size, + in_features_block_size), + w(std::move(w)), + bias_(std::move(bias)), + col_offsets(std::move(col_offsets)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(q_scheme) {} + std::unique_ptr> w; + std::optional bias_; + std::vector col_offsets; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(const at::Tensor& input) override { + TORCH_INTERNAL_ASSERT( + false, + "Sparse quantized dynamic linear with fused relu is not yet " + "supported on qnnpack backend."); + return at::Tensor(); + } + at::Tensor apply_dynamic_relu(const at::Tensor& input) override { + TORCH_INTERNAL_ASSERT( + false, + "Sparse quantized dynamic linear with fused relu is not yet " + "supported on qnnpack backend."); + return at::Tensor(); + } + + LinearPackedSerializationType unpack() override; + + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + const at::Tensor& weight, + const std::optional& bias, + const int64_t out_features_block_size, + const int64_t 
in_features_block_size); + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +}; + +} // namespace ao::sparse + +#endif // USE_FBGEMM + +namespace ao::sparse { +int register_linear_params(); +} // namespace ao::sparse diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h new file mode 100644 index 0000000000000000000000000000000000000000..c75ae6b18aa0f5238662e44b4ae9a349223f613b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h @@ -0,0 +1,74 @@ +#pragma once + +#include + +#include + +namespace ao::sparse { + +// +using LinearPackedSerializationType = + std::tuple, std::vector>; + +#define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 2 + +using BCSRSerializationType = + std::tuple< + int64_t, // Serialization Version + std::optional, // Bias + int64_t, // Out Features (Row) Block Size + int64_t, // In Features (Column) Block Size + at::Tensor, // Weight Scales (single element vector if per-tensor) (float) + at::Tensor, // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t) + bool, // Quantization Scheme (true: per tensor, false: per channel) + at::Tensor, // Wrapper for Row Block Indices (int8_t, int16_t, or int32_t) + at::Tensor, // Wrapper for Column Block Indices (int8_t, int16_t, or int32_t) + at::Tensor, // Wrapper for Non-Zero Weight Values, each +128 (uint8_t) + int64_t, // Number of Output Channels + int64_t // Number of Input Channels + >; + +using BCSR = + std::tuple< + std::vector, // Non-Zero Weight Values + std::vector, // Compressed Row Block Indices + std::vector // Column Block Indices + >; + +struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { + public: + LinearPackedParamsBase( + const int64_t out_features_block_size, + const int64_t in_features_block_size) + : out_features_block_size_(out_features_block_size), + in_features_block_size_(in_features_block_size) {} + + virtual at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + + virtual at::Tensor apply_dynamic(const at::Tensor& input) = 0; + virtual at::Tensor apply_dynamic_relu(const at::Tensor& input) = 0; + + virtual LinearPackedSerializationType unpack() = 0; + + virtual BCSRSerializationType serialize() = 0; + + virtual std::optional bias() = 0; + + virtual void set_bias(const std::optional& bias) { + throw std::runtime_error( + "set_bias is not implemented for this packed " + "parameter type"); + } + + protected: + const int64_t out_features_block_size_, in_features_block_size_; +}; + +} // namespace ao::sparse diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..563a2b26fb9593cd78e0fb7f7e15cc5c2b86a7e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include + +#ifdef USE_PYTORCH_QNNPACK +// TODO: Refacto QnnpackUtils.h so as to separate code +// needed for quantized op from the generic qnnpack specific 
+// quantization utilities. +#include +#include +#include + +namespace ao::sparse { + +struct TORCH_API PackedLinearWeightQnnp : public LinearPackedParamsBase { + PackedLinearWeightQnnp(const at::Tensor& weight, const std::optional& bias, const int64_t out_features_block_size /* block sparsity size across output_features */, const int64_t in_features_block_size /* block sparsity size across input_features */); + explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized); + std::optional orig_bias_; + // Separate copy of bias exist so that we can fill in zeros when + // optional bias does not exist. This is to compy with qnnpack operator that + // expects bias to be present. + // In case bias is present bias_ is just a reference to orig_bias_ + at::Tensor bias_; + c10::QScheme q_scheme_; + double input_scale_{}; + std::unique_ptr bcsr_matrix_; + at::Tensor w_scales_; + std::vector w_zero_points_; + std::vector requantization_scales_; + std::unique_ptr + sparse_linear_op_{nullptr}; + int64_t output_channels_; + int64_t input_channels_; + // Deserialized Tensors are stored to maintain the lifetime of underlying + // BCSR data. + // These are left empty if PackedLinearWeightQnnp is created via prepacking + // rather than deserializing. + at::Tensor deserialized_bcsr_row_block_indices_; + at::Tensor deserialized_bcsr_col_block_indices_; + at::Tensor deserialized_bcsr_weight_values_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override { + TORCH_CHECK( + false, "Static quantized sparse linear unimplemented on QNNPACK"); + } + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override { + TORCH_CHECK( + false, "Static quantized sparse linear unimplemented on QNNPACK"); + } + + at::Tensor apply_dynamic(const at::Tensor& input) override; + at::Tensor apply_dynamic_relu(const at::Tensor& input) override; + + LinearPackedSerializationType unpack() override; + + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + + std::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + const at::Tensor& weight, + const std::optional& bias, + const int64_t out_features_block_size, + const int64_t in_features_block_size); + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + template + at::Tensor apply_dynamic_impl(const at::Tensor& input); +}; + +} // namespace ao::sparse + +#endif // USE_PYTORCH_QNNPACK diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/batch_norm.h b/phivenv/Lib/site-packages/torch/include/ATen/native/batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..6ae48a0089cf6e5a65e55f0bb0dcb07fdf202419 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/batch_norm.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace at::native { + +using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&, + const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double); +using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&); +using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, + const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double); + +DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub) 
+DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub) +DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub) + +// TensorAccessor when it is defined to work around undefined... +template +static TensorAccessor conditional_accessor_1d(const Tensor& t) { + if (! t.defined()) { + return TensorAccessor(nullptr, nullptr, nullptr); + } + return t.accessor(); +} + +template +static scalar_t* conditional_data_ptr(const Tensor& t) { + if constexpr (std::is_const_v) { + return t.defined() ? t.contiguous().const_data_ptr() + : nullptr; + } else { + return t.defined() ? t.contiguous().data_ptr() + : nullptr; + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h new file mode 100644 index 0000000000000000000000000000000000000000..5f2fe7f1a32f50cd35585d4f1060cd76386beec5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h @@ -0,0 +1,37 @@ +#ifndef ATOMIC_ADD_FLOAT +#define ATOMIC_ADD_FLOAT + +#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__)) +#include +#else +#define _mm_pause() +#endif + +#include + +static inline void cpu_atomic_add_float(float* dst, float fvalue) +{ + typedef union { + unsigned intV; + float floatV; + } uf32_t; + + uf32_t new_value, old_value; + std::atomic* dst_intV = (std::atomic*)(dst); + + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + + unsigned* old_intV = (unsigned*)(&old_value.intV); + while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) { +#ifdef __aarch64__ + __asm__ __volatile__("yield;" : : : "memory"); +#else + _mm_pause(); +#endif + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + } +} + +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/CatKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/CatKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..e4f43b4624c9eeb6a4de8039db8546e1c261ae25 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/CatKernel.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t); +DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..355d27bb306c490a84c4cf293a3f2c56e3829248 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t); +DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel) + +} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/CopyKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/CopyKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..469f399ca637ae886fa93cf584d6b8c3b44f681f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/CopyKernel.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace at 
{ +struct TensorIteratorBase; + +namespace native { +inline namespace CPU_CAPABILITY { + +void direct_copy_kernel(TensorIteratorBase &iter); +void copy_kernel(TensorIterator& iter, bool /*non_blocking*/); + +}}} // namespace at::native::CPU_CAPABILITY diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..09ee71a1310f5d8d2ca41520a7870dafa1f04eea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +/* + Depthwise 3x3 Winograd convolution operator +*/ + +namespace at { +class Tensor; + +namespace native { + +using convolution_depthwise3x3_winograd_fn = + Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t); + +DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub) + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..300d4e6d6d977067057b61b0da2f4e51c5b3d6b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h @@ -0,0 +1,425 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPU_CAPABILITY_AVX2 +#include +#include +#endif + + + + +namespace at::native::templates::cpu { +namespace { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) { + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t { + uniform_int_from_to_distribution random(range, base); + return random(generator); + }); + }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [generator]() -> scalar_t { + uniform_int_full_range_distribution random; + return random(generator); + }); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG generator) { + std::lock_guard 
lock(generator->mutex_); + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] { + cpu_serial_kernel(iter, [generator]() -> scalar_t { + uniform_int_distribution random; + return random(generator); + }); + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_kernel(iter, check_generator(gen)); + } +}; + +// ==================================================== Normal ======================================================== + +#ifdef CPU_CAPABILITY_AVX2 +static void normal_fill_16_AVX2(float *data, + const __m256* two_pi, + const __m256* one, + const __m256* minus_two, + const __m256* mean, + const __m256* std_v) { + const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data)); + const __m256 u2 = _mm256_loadu_ps(data + 8); + // sincos256_ps and log256_ps are from avx_mathfun.h + const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1))); + const __m256 theta = _mm256_mul_ps(*two_pi, u2); + __m256 sintheta, costheta; + sincos256_ps(theta, &sintheta, &costheta); + const __m256 n1 = _mm256_mul_ps(radius, costheta); + const __m256 n2 = _mm256_mul_ps(radius, sintheta); + _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean)); + _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean)); +} + +template +void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) { + float *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi); + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 minus_two = _mm256_set1_ps(-2.0f); + const __m256 mean_v = _mm256_set1_ps(mean); + const __m256 std_v = _mm256_set1_ps(std); + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v); + } + + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v); + } +} +#endif + +template +static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) { + for (const auto j : c10::irange(8)) { + const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log. 
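+    // Box-Muller transform: the pair of uniforms (u1, u2) is mapped to two
+    // independent standard normal samples via radius = sqrt(-2 * ln(u1)) and
+    // theta = 2 * pi * u2, which are then scaled by `std` and shifted by `mean`.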
+ const scalar_t u2 = data[j + 8]; + const scalar_t radius = std::sqrt(-2 * std::log(u1)); + const scalar_t theta = 2.0f * c10::pi * u2; + data[j] = radius * std::cos(theta) * std + mean; + data[j + 8] = radius * std::sin(theta) * std + mean; + } +} + +#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX) +static void normal_fill_16_VSX(float *data,const Vectorized &two_pi,const Vectorized &one,const Vectorized &minus_two,const Vectorized &mean,const Vectorized &std) { + using Vec = Vectorized; + Vec u1=one-Vec::loadu(data); + Vec u2=Vec::loadu(data+8); + Vec radius=(minus_two * u1.log()); + radius=radius.sqrt(); + Vec theta=two_pi * u2; + Vec output_vec=radius * theta.cos() * std + mean; + Vec output_vec2=radius * theta.sin() * std + mean; + output_vec.store(data); + output_vec2.store(data+8); +} + +template +void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { + float *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + + using Vec = Vectorized; + const Vec two_pi = Vec(2.0f * c10::pi); + const Vec one = Vec(1.0f); + const Vec minus_two = Vec(-2.0f); + const Vec var_vec = Vec(std); + const Vec mean_vec = Vec(mean); + + for (int64_t i = 0; i < size - 15; i += 16) { + if(Vec::size()==8) { + normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec); + } + else{ + normal_fill_16(data + i, mean, std); + } + } + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + if(Vec::size()==8){ + normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec); + } + else{ + normal_fill_16(data, mean, std); + } + } +} +#endif //VSX + +template +void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { + scalar_t *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16(data + i, mean, std); + } + if (size % 16 != 0) { + // Recompute the last 16 values. 
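+    // The blocked-by-16 loop above only covers full blocks, so when `size` is
+    // not a multiple of 16 the final (possibly overlapping) block of 16 values
+    // is simply redrawn from the uniform distribution and transformed again,
+    // overwriting the overlapping entries.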
+ data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16(data, mean, std); + } +} + +template +void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) { + auto size = self.numel(); + if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) { +#ifdef CPU_CAPABILITY_AVX2 + normal_fill_AVX2(self, static_cast(mean), static_cast(std), generator); +#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) + normal_fill_VSX(self, static_cast(mean), static_cast(std), generator); +#else + normal_fill(self, static_cast(mean), static_cast(std), generator); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { + if (size >= 16 && self.is_contiguous()) { + normal_fill(self, static_cast(mean), static_cast(std), generator); + } else { + auto iter = TensorIterator::borrowing_nullary_op(self); + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t { + at::normal_distribution normal(mean, std); + return static_cast(normal(generator)); + }); + } + }); + } +} + +template +struct NormalKernel { + void operator()(Tensor& self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================= + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + auto from = static_cast(from_); + auto to = static_cast(to_); + at::uniform_real_distribution uniform(from, to); + cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t { + return static_cast(uniform(generator)); + }); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::cauchy_distribution cauchy(median, sigma); + cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t { + return static_cast(cauchy(generator)); + }); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::lognormal_distribution logNormal(mean, std); + cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t { + return static_cast(logNormal(generator)); + }); + }); +} + +template +struct LogNormalKernel { + void 
operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::geometric_distribution geometric(p); + cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t { + return static_cast(geometric(generator)); + }); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// ================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::exponential_distribution exponential(lambda); + cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t { + return static_cast(exponential(generator)); + }); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ================================================== Bernoulli ======================================================= + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + using self_t = scalar_t; + auto p_cpu = p_.to(kCPU); + auto p = expand_inplace(self, p_cpu); + auto iter = TensorIteratorConfig() + .add_output(self) + .add_const_input(*p) + .check_all_same_dtype(false) + .build(); + if (p->scalar_type() == kDouble) { + cpu_serial_kernel(iter, [&](const double p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, + p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] { + using p_t = scalar_t; + cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + }); + } + }); +} + +template +void bernoulli_kernel(const TensorBase &self, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_scalar_cpu_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + auto iter = TensorIterator::borrowing_nullary_op(self); + cpu_serial_kernel(iter, [p, generator]() -> scalar_t { + at::bernoulli_distribution bernoulli(p); + return 
static_cast(bernoulli(generator)); + }); + }); +} + +template +struct BernoulliKernel { + void operator()(const TensorBase &self, double p, std::optional gen) { + bernoulli_kernel(self, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Elu.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Elu.h new file mode 100644 index 0000000000000000000000000000000000000000..cf8f62782cc793ca3d885009c9bbb79a27692c34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Elu.h @@ -0,0 +1,74 @@ +#pragma once + +// On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to +// access constants such as M_SQRT2 and M_2_SQRTPI. +#ifdef _WIN32 +#define _USE_MATH_DEFINES +#include +#endif // _WIN32 + +#include +#include // For c10::is_reduced_floating_point_v. + +namespace at::native { +inline namespace CPU_CAPABILITY { +/** + * Return a function object that calculates ELU with the given + * parameters on its input element. ParamT is the type of the input + * and output to the ELU, and MathT is the type (possibly + * higher-precision, e.g. float if ParamT is reduced-precision float) + * in which to do intermediate calculations. + */ +template +auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) { + const auto negcoef = alpha * scale; + const auto poscoef = scale; + const auto negiptcoef = input_scale; + return [negcoef, negiptcoef, poscoef](ParamT a) -> ParamT { + return MathT(a) < MathT(0) + ? std::expm1(MathT(a) * negiptcoef) * negcoef + : MathT(a) * poscoef; + }; +} + +/** + * Return a function object that calculates ELU with the given + * parameters on its input element. The function object takes and + * returns Vectorized. + */ +template , bool> = true> +auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) { + const vec::Vectorized negcoef_vec(alpha * scale); + const vec::Vectorized poscoef_vec(scale); + const vec::Vectorized negiptcoef_vec(input_scale); + const vec::Vectorized zero_vec(static_cast(0)); + return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized a) -> vec::Vectorized { + const auto cmp = a >= zero_vec; + if (!cmp.zero_mask()) { + return a * poscoef_vec; + } else { + return vec::Vectorized::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp); + } + }; +} + +/** + * Return a function object that calculates ELU with the given + * parameters on its input element. The function object takes and + * returns Vectorized, and Vectorized is the type + * (possibly higher-precision) in which to do intermediate + * calculations. + */ +template , bool> = true> +auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) { + // Takes float->float. 
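+  // For reduced-precision element types (BFloat16/Half), the input vector is
+  // split into two float vectors, the float ELU functor is applied to each
+  // half, and the two results are packed back into the reduced-precision
+  // vector type.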
+ const auto float_func = get_vectorized_elu_elementwise_func(alpha, scale, input_scale); + return [float_func](vec::Vectorized a) -> vec::Vectorized { + auto [a0, a1] = vec::convert_to_float(a); + auto res0 = float_func(a0); + auto res1 = float_func(a1); + return vec::convert_from_float(res0, res1); + }; +} +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Gelu.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Gelu.h new file mode 100644 index 0000000000000000000000000000000000000000..1b9cc5248e6ce3c0228ee1552e86b7f942cfd3f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Gelu.h @@ -0,0 +1,83 @@ +#pragma once + +// On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to +// access constants such as M_SQRT2 and M_2_SQRTPI. +#ifdef _WIN32 +#define _USE_MATH_DEFINES +#include +#include +#endif // _WIN32 + +#include +#include // For c10::is_reduced_floating_point_v. + +namespace at::native { +inline namespace CPU_CAPABILITY { +constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5; +constexpr double kGeluKappa = 0.044715; + +template +using reduced_fp_to_float_t = std::conditional_t, float, T>; + +template , bool> = true> +float reduced_fp_to_float(T x) { + return float(x); +} + +template , bool> = true> +T reduced_fp_to_float(T x) { + return x; +} + +template +T scalar_gelu_approximated_with_tanh(T x) { + using opmath_t = reduced_fp_to_float_t; + auto x_float = reduced_fp_to_float(x); + auto x_cube = x_float * x_float * x_float; + auto inner = opmath_t(kGeluBeta) * (x_float + opmath_t(kGeluKappa) * x_cube); + return opmath_t(0.5) * x_float * (opmath_t(1) + std::tanh(inner)); +} + +template , bool> = true> +vec::Vectorized vectorized_gelu_approximated_with_tanh(vec::Vectorized x) { + const vec::Vectorized kPointFiveVec(T(0.5)); + const vec::Vectorized kOneVec(T(1)); + const vec::Vectorized kGeluBetaVec((T(kGeluBeta))); + const vec::Vectorized kGeluKappaVec((T(kGeluKappa))); + auto x_cube = x * x * x; + vec::Vectorized inner_vec = kGeluBetaVec * (x + kGeluKappaVec * x_cube); + return kPointFiveVec * x * (kOneVec + inner_vec.tanh()); +} + +template , bool> = true> +vec::Vectorized vectorized_gelu_approximated_with_tanh(vec::Vectorized x) { + auto [x0, x1] = at::vec::convert_to_float(x); + return at::vec::convert_from_float( + vectorized_gelu_approximated_with_tanh(x0), + vectorized_gelu_approximated_with_tanh(x1)); +} + + +template +T scalar_gelu(T x) { + using opmath_t = reduced_fp_to_float_t; + const auto kAlpha = opmath_t(M_SQRT1_2); + return reduced_fp_to_float(x) * opmath_t(0.5) * (opmath_t(1) + std::erf(reduced_fp_to_float(x) * kAlpha)); +} + +template, bool> = true> +vec::Vectorized vectorized_gelu(vec::Vectorized x) { + const vec::Vectorized kAlphaVec(T(M_SQRT1_2)); + const vec::Vectorized kOneVec(T(1)); + const vec::Vectorized kPointFiveVec(T(0.5)); + return x * kPointFiveVec * (kOneVec + (x * kAlphaVec).erf()); +} + +template, bool> = true> +vec::Vectorized vectorized_gelu(vec::Vectorized x) { + auto [x0, x1] = at::vec::convert_to_float(x); + return at::vec::convert_from_float(vectorized_gelu(x0), vectorized_gelu(x1)); +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f0c8439f6f80e025c98d4a698464d5bd6244916e 
--- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using forward_2d_fn = void (*) ( + const TensorBase &output, + const TensorBase &input, + const TensorBase &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners); +using backward_2d_fn = void (*) ( + const TensorBase &grad_input, + const TensorBase &grad_grid, + const TensorBase &grad_output, + const TensorBase &input, + const TensorBase &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask); +DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel) +DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7b1284bb4026f07c55e7b2496ec8ad7003fa1bc9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h @@ -0,0 +1,85 @@ +#pragma once +#include +#include + +namespace at::native { + +inline bool is_constant_index(int ntensor, const int64_t* strides) { + AT_ASSERT(ntensor >= 3); + for (const auto arg : c10::irange(2, ntensor)) { + if (strides[arg] != 0) { + return false; + } + } + return true; +} + + +struct Indexer { + Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides, + IntArrayRef original_sizes, IntArrayRef original_strides) + : num_indexers(num_indexers) + , indexers(indexers) + , indexer_strides(indexer_strides) + , original_strides(original_strides.data()) + , original_sizes(original_sizes.data()) { + AT_ASSERT(static_cast(original_strides.size()) == num_indexers); + AT_ASSERT(static_cast(original_sizes.size()) == num_indexers); + } + + int64_t num_indexers; + char** indexers; + const int64_t* indexer_strides; + const int64_t* original_strides; + const int64_t* original_sizes; + + int64_t get(int64_t idx) { + int64_t offset = 0; + for (const auto j : c10::irange(num_indexers)) { + int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]]; + int64_t size = original_sizes[j]; + TORCH_CHECK_INDEX(value >= -size && value < size, + "index ", value, " is out of bounds for dimension ", j, " with size ", size); + if (value < 0) { + value += size; + } + offset += value * original_strides[j]; + } + return offset; + } +}; + +template +void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, + const func_t& f, bool serial_execution=false) +{ + int ntensor = iter.ntensors(); + // When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE + // to make the whole available thread numbers get more balanced work load and a better cache location. 
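+  // (at::internal::GRAIN_SIZE is the default minimum per-thread work size used
+  // by at::parallel_for; choosing a smaller grain size here yields more,
+  // finer-grained chunks for the parallel loop.)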
+ // The grain size here is chosen by the op benchmark to overcome the thread launch overhead + const int index_parallel_grain_size = 3000; + auto loop = [&](char** data, const int64_t* strides, int64_t n) { + auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride); + char* dst = data[0]; + char* src = data[1]; + if (is_constant_index(ntensor, strides)) { + // specialization for when every element uses the same index + int64_t offset = indexer.get(0); + for (const auto i : c10::irange(n)) { + f(dst + strides[0] * i, src + strides[1] * i, offset); + } + } else { + for (const auto i : c10::irange(n)) { + int64_t offset = indexer.get(i); + f(dst + strides[0] * i, src + strides[1] * i, offset); + } + } + }; + if (serial_execution) { + iter.serial_for_each(loop, {0, iter.numel()}); + } else { + iter.for_each(loop, index_parallel_grain_size); + } +} +} // at +// native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Intrinsics.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..c85239e5a7067907af8c7e903208f2d4338c8213 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Intrinsics.h @@ -0,0 +1,33 @@ +#pragma once + +#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__)) +/* Clang-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(_MSC_VER) +/* Microsoft C/C++-compatible compiler */ +#include +#if _MSC_VER <= 1900 +#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y]) +#endif +#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(__GNUC__) && defined(__ARM_NEON__) +/* GCC-compatible compiler, targeting ARM with NEON */ +#include +#elif defined(__GNUC__) && defined(__IWMMXT__) +/* GCC-compatible compiler, targeting ARM with WMMX */ +#include +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) +/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ +#include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel +#elif defined(__GNUC__) && defined(__SPE__) +/* GCC-compatible compiler, targeting PowerPC with SPE */ +#include +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/IsContiguous.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/IsContiguous.h new file mode 100644 index 0000000000000000000000000000000000000000..6ff1afb8777f6af721d162b55399ffe958984086 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/IsContiguous.h @@ -0,0 +1,64 @@ +#pragma once + +namespace at::native { inline namespace CPU_CAPABILITY { + +// n: number of function arguments (arity) +// traits: function_traits (see FunctionTraits.h) +// s: index of scalar argument or -1 +template +struct IsContiguous { + static bool eval(const int64_t* strides) { + using type = typename traits::template arg::type; + return strides[stride_index] == (s == n ? 
0 : sizeof(type)) && + IsContiguous::eval(strides); + } +}; + +// will be called when there is an output exists +template +struct IsContiguous<0, 0, traits, s> { + static bool eval(const int64_t* strides) { + return strides[0] == sizeof(typename traits::result_type); + } +}; + +// will be called when there is no output +template +struct IsContiguous<0, -1, traits, s> { + static bool eval(const int64_t* /*strides*/) { + return true; + } +}; + +// output and all inputs are contiguous +template < + typename traits, + std::enable_if_t>* = + nullptr> +static inline bool is_contiguous(const int64_t* strides) { + return IsContiguous::eval(strides); +} + +template >* = nullptr> +static inline bool is_contiguous(const int64_t* strides) { + return IsContiguous::eval(strides); +} + +// input at `s` is scalar (stride 0); output and other inputs are contiguous +// NB: output is typically at strides[0] so first input corresponds to s=1 +template >* = nullptr> +static inline bool is_contiguous_scalar(const int64_t* strides) { + static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); + return IsContiguous::eval(strides); +} + +template >* = nullptr> +static inline bool is_contiguous_scalar(const int64_t* strides) { + static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); + return IsContiguous::eval(strides); +} + +}} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/LogAddExp.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/LogAddExp.h new file mode 100644 index 0000000000000000000000000000000000000000..ebed2b70ea820d0761110ea438905842f0877385 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/LogAddExp.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +// custom min and max to be used in logcumsumexp for complex arguments +template +std::pair, c10::complex> _logcumsumexp_minmax(c10::complex x, c10::complex y) { + if (at::_isnan(y)) { // either real is nan or imag is nan + return std::make_pair(y, y); + } else if (at::_isnan(x)) { // either real is nan or imag is nan + return std::make_pair(x, x); + } else { + return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x); + } +} + +template +scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) { + // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp + scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan + scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan + if (min != max || std::isfinite(min)) { + // nan will be propagated here + return std::log1p(std::exp(min - max)) + max; + } else { + // special case to correctly handle infinite cases + return x; + } +} + +template +c10::complex _log_add_exp_helper(const c10::complex& x, const c10::complex& y) { + auto [min, max] = _logcumsumexp_minmax(x, y); + auto min_real = std::real(min); + auto max_real = std::real(max); + + if (at::_isnan(min)) { // either real is nan or imag is nan + // handling the "infectious" NaNs + return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; + } else if (!std::isfinite(min_real) && (min_real == max_real)) { + if (min_real < 0) { + // handle the -inf case, the imaginary part here does not really matter as the exp(value) + // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined. 
+ // It does not matter if we're taking the exp of this value + return min; + } else { + // handle the +inf case, we don't need the special precision for log1p for small values + // and to avoid producing nan in case of real(max) == real(min) == +inf + return std::log(std::exp(min) + std::exp(max)); + } + } else { + return std::log1p(std::exp(min - max)) + max; + } +} + +} // end namespace +} //end at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..0c979702030c5b41b1e0d09e64ba1ba091886ff5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h @@ -0,0 +1,337 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { +template +int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) { + // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the + // size of L1D cache on many processors. Some processors have 48 KB L1D cache + // nowadays, so maybe in the future, we can leverage the knowledge of a + // machine's L1D cache size. + int64_t MAX_CHUNK_SIZE = std::max( + 1, + grain_size / (sizeof(scalar_t) * dim_size)); + return std::min(MAX_CHUNK_SIZE, outer_size); +} + +template +void serial_vec_log_softmax_lastdim_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t dim_size, + int64_t chunk_size, + int64_t begin, + int64_t end) { + if (end <= begin) { + return; + } + using Vec = vec::Vectorized>; + // MSVC requires such a declaration of dynamic arrays + // Source: https://stackoverflow.com/a/33423538 + auto tmp_sum_scalar = std::make_unique(chunk_size); + auto max_input_arr = std::make_unique(chunk_size); + for (int64_t ii = begin; ii < end; ii += chunk_size) { + int64_t loop_end = chunk_size; + if (ii + chunk_size > end) { + loop_end = end - ii; + } + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + max_input_arr[j] = vec::reduce_all( + [](Vec& x, Vec& y) { return vec::maximum(x, y); }, + input_data, + dim_size); + } + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + scalar_t max_input = max_input_arr[j]; + tmp_sum_scalar[j] = vec::map_reduce_all( + [max_input](Vec x) { return (x - Vec(max_input)).exp(); }, + [](Vec x, Vec y) { return x + y; }, + input_data, + dim_size); + } + // See [Note AVX-SSE transitions] for why this should call the + // vectorized version (aside from perf improvements). + vec::map( + [](Vec x) { return x.log(); }, + tmp_sum_scalar.get(), + tmp_sum_scalar.get(), + loop_end); + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + scalar_t* output_data = output_data_base + i * dim_size; + scalar_t tmp_sum = tmp_sum_scalar[j]; + scalar_t max_input = max_input_arr[j]; + + // It's necessary to keep the order of the operations below. + // In some cases that input is large digits and the difference + // is small, if we compute `max_input` plus `tmp_sum` before, + // there would be a numerical problem. 
See an example in + // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379 + vec::map( + [tmp_sum, max_input](Vec x) { + return x - Vec(max_input) - Vec(tmp_sum); + }, + output_data, + input_data, + dim_size); + } + } +} + +// Can't include ATen/Parallel.h. +// TODO: find a way to have only one copy of divup. +inline int64_t divup(int64_t x, int64_t y) { + return (x + y - 1) / y; +} + +template +std::pair vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) { + using Vec = vec::Vectorized; + int64_t MAX_CHUNK_SIZE = std::max(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size()); + MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size(); + int64_t CHUNK_SIZE = std::min(MAX_CHUNK_SIZE, inner_size); + int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + return {CHUNK_SIZE, num_chunks}; +} + +template +std::enable_if_t>, void> +serial_vec_logsoftmax_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t inner_size, + int64_t chunk_size, + int64_t num_chunks, + int64_t dim_size, + int64_t begin, + int64_t end) { + using Vec = vec::Vectorized; + // thread local temp buffer which holds vertical reduction result: max and sum. + auto buffer = std::make_unique(chunk_size * 2); + scalar_t* input_max_data = buffer.get(); + scalar_t* tmp_sum_data = buffer.get() + chunk_size; + + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * chunk_size; + int64_t size = std::min(chunk_size, inner_size - inner_idx_begin); + + // init + Vec zero_vec = Vec(scalar_t(0)); + Vec min_vec = Vec(-std::numeric_limits::infinity()); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + min_vec.store(input_max_data + d0); + zero_vec.store(tmp_sum_data + d0); + } + for (; d0 < size; d0++) { + input_max_data[d0] = -std::numeric_limits::infinity(); + tmp_sum_data[d0] = scalar_t(0); + } + + // compute max + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d1); + Vec max_vec = Vec::loadu(input_max_data + d1); + max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec); + max_vec.store(input_max_data + d1); + } + for (; d1 < size; d1++) { + scalar_t data_val = input_ptr[d1]; + scalar_t max_val = input_max_data[d1]; + input_max_data[d1] = data_val > max_val ? 
data_val : max_val; + } + } + + // compute sum of (x - max).exp() + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d2); + Vec sum_vec = Vec::loadu(tmp_sum_data + d2); + Vec max_vec = Vec::loadu(input_max_data + d2); + sum_vec += (data_vec - max_vec).exp(); + sum_vec.store(tmp_sum_data + d2); + } + for (; d2 < size; d2++) { + scalar_t data_val = input_ptr[d2]; + scalar_t max_val = input_max_data[d2]; + tmp_sum_data[d2] += std::exp(data_val - max_val); + } + } + + // apply log + vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); + + // compute x - max - sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin; + const scalar_t* input_ptr = input_data_base + offset; + scalar_t* output_ptr = output_data_base + offset; + + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d3); + Vec max_vec = Vec::loadu(input_max_data + d3); + Vec sum_vec = Vec::loadu(tmp_sum_data + d3); + Vec out_vec = data_vec - max_vec - sum_vec; + out_vec.store(output_ptr + d3); + } + for (; d3 < size; d3++) { + output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]; + } + } + } +} + +template +std::enable_if_t>, void> +serial_vec_logsoftmax_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t inner_size, + int64_t chunk_size, + int64_t num_chunks, + int64_t dim_size, + int64_t begin, + int64_t end) { + using Vec = vec::Vectorized; + using fVec = vec::Vectorized; + auto buffer = std::make_unique(chunk_size * 2); + float* input_max_data = buffer.get(); + float* tmp_sum_data = buffer.get() + chunk_size; + + // thread local buffer that holds input data in float32 to save next 2 dtype conversion + auto input_buffer = std::make_unique(dim_size * chunk_size); + float* input_buffer_data = input_buffer.get(); + + // init + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * chunk_size; + int64_t size = std::min(chunk_size, inner_size - inner_idx_begin); + + fVec zero_fvec = fVec(float(0)); + fVec min_fvec = fVec(-std::numeric_limits::infinity()); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + min_fvec.store(input_max_data + d0); + min_fvec.store(input_max_data + d0 + fVec::size()); + zero_fvec.store(tmp_sum_data + d0); + zero_fvec.store(tmp_sum_data + d0 + fVec::size()); + } + for (; d0 < size; d0++) { + input_max_data[d0] = -std::numeric_limits::infinity(); + tmp_sum_data[d0] = float(0); + } + + // compute max + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d1); + auto [data_fvec0, data_fvec1] = vec::convert_to_float(data_vec); + fVec max_fvec0 = fVec::loadu(input_max_data + d1); + fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size()); + max_fvec0 = fVec::blendv(max_fvec0, 
data_fvec0, data_fvec0 > max_fvec0); + max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1); + max_fvec0.store(input_max_data + d1); + max_fvec1.store(input_max_data + d1 + fVec::size()); + + // cache the 'converted' float input + data_fvec0.store(input_buffer_ptr + d1); + data_fvec1.store(input_buffer_ptr + d1 + fVec::size()); + } + for (; d1 < size; d1++) { + float data_val = float(input_ptr[d1]); + float max_val = input_max_data[d1]; + input_max_data[d1] = data_val > max_val ? data_val : max_val; + input_buffer_ptr[d1] = data_val; + } + } + + // compute sum of (x - max).exp() + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2); + fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size()); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size()); + fVec max_fvec0 = fVec::loadu(input_max_data + d2); + fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size()); + sum_fvec0 += (data_fvec0 - max_fvec0).exp(); + sum_fvec1 += (data_fvec1 - max_fvec1).exp(); + sum_fvec0.store(tmp_sum_data + d2); + sum_fvec1.store(tmp_sum_data + d2 + fVec::size()); + } + for (; d2 < size; d2++) { + float data_val = input_buffer_ptr[d2]; + float max_val = input_max_data[d2]; + tmp_sum_data[d2] += std::exp(data_val - max_val); + } + } + + // apply log + vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); + + // compute x - max - sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3); + fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size()); + fVec max_fvec0 = fVec::loadu(input_max_data + d3); + fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size()); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size()); + fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0; + fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1; + Vec out_vec = vec::convert_from_float(out_fvec0, out_fvec1); + out_vec.store(output_ptr + d3); + } + for (; d3 < size; d3++) { + output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]); + } + } + } +} // namespace CPU_CAPABILITY +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Loops.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Loops.h new file mode 100644 index 0000000000000000000000000000000000000000..98ef9392cd4019422469a2ad061a76c290214af1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Loops.h @@ -0,0 +1,395 @@ +#pragma once + +// This file provides two functions to help write elementwise kernels: +// +// cpu_kernel(TensorIterator iter, ) +// cpu_kernel_vec(TensorIterator iter, , ) +// +// Both functions may generate vectorized code. The cpu_kernel implementation +// relies on the compiler's auto-vectorization. The cpu_kernel_vec +// implementation uses x86 SIMD intrinsics when available. 
These functions +// are only intended to be used in the ATen/native/cpu subdirectory, since files +// in other directories are not compiled with AVX/AVX2 enabled. See README.md +// for more details. +// +// For example, to write a multiplication kernel for float: +// +// cpu_kernel(iter, [](float a, float b) { return a * b; }); +// +// Or you may write: +// +// cpu_kernel_vec(iter, +// [](float a, float b) { return a * b; }, +// [](Vectorized a, Vectorized b) { return a * b; }); +// +// See BinaryOpsKernel.cpp for the complete implementation +// +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at::native { inline namespace CPU_CAPABILITY { + +using namespace vec; + +template +typename traits::ArgsTuple +dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, + std::index_sequence) { + return std::make_tuple( + c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template +typename traits::ArgsTuple +dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) { + using Indices = std::make_index_sequence; + return dereference_impl(data, strides, i, Indices{}); +} + +template +typename traits::ArgsTuple +dereference_vec_impl(char* C10_RESTRICT data[], + const typename traits::result_type& opt_scalar, + size_t S, + int64_t i, + std::index_sequence) { + using Vec = typename traits::result_type; + using scalar_t = typename Vec::value_type; + return std::make_tuple( + S == INDEX + 1 ? + opt_scalar : + Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...); +} + +template +typename traits::ArgsTuple +dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) { + using Indices = std::make_index_sequence; + return dereference_vec_impl(data, opt_scalar, S, i, Indices{}); +} + +template ::result_type>>* = nullptr> +inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + using result_type = typename traits::result_type; + for (; i < n; i++) { + result_type* out_ptr = (result_type*)(data[0] + i * strides[0]); + *out_ptr = c10::guts::apply(op, dereference( + &data[1], + &strides[1], + i)); + } +} + +template ::result_type>>* = nullptr> +inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + for (; i < n; i++) { + c10::guts::apply(op, dereference( + &data[0], + &strides[0], + i)); + } +} + +// Basic loop operation (one output, N inputs). May be auto-vectorized +// by the compiler. Supports inputs and outputs of different types. +template +inline void +basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + constexpr int ntensors = traits::arity + 1; + + // Copying strides to temporary array helps auto vectorization in older GCC + // versions. 
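+  // Illustrative aside (assumption, not part of the original header): for a
+  // binary op such as float addition, ntensors == 3 and strides_ holds the byte
+  // strides of {output, input0, input1}; for contiguous float tensors the local
+  // copy below is simply {4, 4, 4}.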
+ int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = strides_[arg]; + } + + execute_op(data, strides, i, n, std::forward(op)); +} + +// the recursive variadic template for iterating over the returned tuple +template +struct TupleOutput { + static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i, + const T &tuple) { + TupleOutput::handle(data, strides, i, tuple); + + auto output = std::get(tuple); + using output_type = decltype(output); + output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]); + *out_ptr = output; + } +}; + +// Base case for the above recursive template +template +struct TupleOutput { + static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i, + const T &tuple) { + auto output = std::get<0>(tuple); + using output_type = decltype(output); + output_type* out_ptr = (output_type *)(data[0] + i * strides[0]); + *out_ptr = output; + } +}; + +template +void handle_tuple_outputs(char* C10_RESTRICT data[], + const int64_t* strides, + int64_t i, + const std::tuple &tuple) { + TupleOutput::handle(data, strides, i, tuple); +} + +// Loop operation for `cpu_kernel_multiple_outputs`. +// 1. Use `c10::guts::apply` to make dynamic method invocation +// for the lambda passed in `cpu_kernel_multiple_outputs`. +// 2. Iterate over the members of the returned tuple, set the corresponding +// output tensor by the tuple member in `handle_tuple_outputs` function. +template +inline void +multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + + using result_type = typename traits::result_type; + constexpr int num_outputs = std::tuple_size_v; + constexpr int ntensors = traits::arity + num_outputs; + + // Copying strides to temporary array helps auto vectorization in older GCC + // versions. + int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = strides_[arg]; + } + + for (; i < n; i++) { + auto output = c10::guts::apply(op, dereference( + &data[num_outputs], + &strides[num_outputs], + i)); + handle_tuple_outputs(data, strides, i, output); + } +} + +// Explicitly vectorized loop implementation. All inputs and outputs must be +// the same type and contiguous with one exception: a single input may be +// a scalar (stride 0). It's position is indicated by the argument `S`. If `S` +// is 0, then there are no scalar inputs. +template +inline void +vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) { + using traits = function_traits; + using scalar_t = typename function_traits::result_type; + using Vec = Vectorized; + constexpr int ntensors = traits::arity + 1; + + char* C10_RESTRICT data[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + data[arg] = data_[arg]; + } + + Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0)); + int64_t i = 0; + for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) { + auto args1 = dereference_vec(&data[1], opt_scalar, S, i); + auto args2 = dereference_vec(&data[1], opt_scalar, S, i + Vec::size()); + auto out1 = c10::guts::apply(vop, std::move(args1)); + auto out2 = c10::guts::apply(vop, std::move(args2)); + out1.store(data[0] + i * sizeof(scalar_t)); + out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t)); + } + if (i < n) { + int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = (S > 0 && arg == S) ? 
0 : sizeof(scalar_t); + } + basic_loop(data, strides, i, n, std::forward(op)); + } +} + + +template +inline void unroll_contiguous_scalar_checks( + const int64_t* /*strides*/, + std::index_sequence<>, + cb_t&& cb) { + cb(0); +} + +template +inline void unroll_contiguous_scalar_checks( + const int64_t* strides, + std::index_sequence, + cb_t&& cb) { + if (is_contiguous_scalar(strides)) { + cb(INDEX0 + 1); + } else { + unroll_contiguous_scalar_checks(strides, std::index_sequence{}, std::forward(cb)); + } +} + +template +struct VectorizedLoop2d { + op_t op; + vop_t vop; + + using traits = function_traits; + static constexpr int ntensors = traits::arity + 1; + using data_t = std::array; + + VectorizedLoop2d(op_t op, vop_t vop): + op(std::move(op)), vop(std::move(vop)) {} + + static void advance(data_t &data, const int64_t *outer_strides) { + for (const auto arg : c10::irange(data.size())) { + data[arg] += outer_strides[arg]; + } + } + + void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) { + data_t data; + std::copy_n(base, ntensors, data.data()); + const int64_t *outer_strides = &strides[ntensors]; + + if (is_contiguous(strides)) { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { + vectorized_loop(data.data(), size0, 0, op, vop); + advance(data, outer_strides); + } + } else { + using Indices = std::make_index_sequence; + unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t idx) { + if (idx) { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { + vectorized_loop(data.data(), size0, idx, op, vop); + advance(data, outer_strides); + } + } else { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { + basic_loop(data.data(), strides, 0, size0, op); + advance(data, outer_strides); + } + } + }); + } + } +}; + +template +VectorizedLoop2d make_vectorized_loop2d( + op_t &&op, vop_t &&vop) { + return VectorizedLoop2d(std::forward(op), std::forward(vop)); +} + +template +void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.for_each([&](char** data, const int64_t* strides, int64_t n) { + // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that + // iter.for_each is ever sending to the loop lambda + basic_loop(data, strides, 0, n, op); + }, grain_size); + iter.cast_outputs(); +} + +// This function helps write elementwise kernels that requires multiple outputs. +// It follows the similar structure of cpu_kernel. +// Instead of `basic_loop` function, a new `multiple_outputs_loop` function is +// manipulated to handle multiple return values. +// For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`) +// of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`. +// The `gpu_kernel_multiple_outputs` is also implemented without this check, +// We could extend `needs_dynamic_casting` to support both `std::tuple` and +// `thrust::tuple` in the future. 
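+// A minimal usage sketch (illustrative only, assuming `iter` was configured with
+// two outputs and two int64 inputs): the lambda returns a std::tuple and each
+// member is written to the corresponding output by handle_tuple_outputs above.
+//
+//   cpu_kernel_multiple_outputs(iter, [](int64_t a, int64_t b) {
+//     return std::make_tuple(a + b, a - b);
+//   });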
+template +void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + + iter.for_each([&](char** data, const int64_t* strides, int64_t n) { + multiple_outputs_loop(data, strides, 0, n, op); + }, grain_size); + iter.cast_outputs(); +} + +template +void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU, but some kernels (like Fill) + // explicitly dynamic_cast, so we give the opt-out of checking. + if constexpr (check_dynamic_cast) { + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + } + + iter.for_each(make_vectorized_loop2d(std::forward(op), std::forward(vop)), grain_size); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) { + using traits = function_traits; + constexpr bool result_void = std::is_void_v; + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity && + ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1))); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) { + basic_loop(data, strides, 0, n, op); + }, range); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) { + cpu_serial_kernel(iter, std::forward(op), {0, iter.numel()}); +} + +template +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.serial_for_each(make_vectorized_loop2d(std::forward(op), std::forward(vop)), range); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) { + cpu_serial_kernel_vec(iter, std::forward(op), std::forward(vop), {0, iter.numel()}); +} + +}} // namespace at::native:: diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ca3d3aa0c6d8f71bc2223f5492b39f3f5e269e74 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class Tensor; + +namespace native { + +using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&); + +DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel) +DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel) + +}} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h new file mode 100644 index 
0000000000000000000000000000000000000000..50be85c97c58e486b4307f616da30672f61b63a4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t); +DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel) +DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel) + +} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Reduce.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..6ec2faf3360bbc29752f4c40357dc1cb192491dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/Reduce.h @@ -0,0 +1,310 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace at::native { inline namespace CPU_CAPABILITY { + +using namespace vec; + +#define VEC_LOOP_HEADER(func_t, data) \ + using scalar_t = typename function_traits::result_type; \ + using Vec = Vectorized; \ + char* out_ptr = data[0]; \ + (void) out_ptr; + +// reduction that is contiguous over the input in dim 0 +template +inline bool is_contiguous_reduction(const int64_t* strides) { + return strides[0] == 0 && + strides[1] == sizeof(typename traits::arg2_t); +} + +// reduction that is contiguous over the input in dim 1 +template +inline bool is_outer_reduction(const int64_t* strides) { + return strides[0] == 0 && + strides[2] == sizeof(typename traits::result_type) && + strides[3] == sizeof(typename traits::arg2_t); +} + +template +inline void vectorized_reduction(char** data, int64_t n, int64_t stride, + func_t op, vec_func_t vop, bool reduce) { + VEC_LOOP_HEADER(func_t, data) + const char* in1_ptr = data[1]; + Vec acc[4]; + for (const auto j : c10::irange(4)) { + acc[j] = Vec::loadu(in1_ptr + j * Vec::size() * sizeof(scalar_t)); + } + for (const auto i : c10::irange(1, n)) { + const char* ptr = in1_ptr + stride * i; + acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t)))); + acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t)))); + acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t)))); + acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t)))); + } + if (reduce) { + scalar_t buffer[Vec::size()]; + acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3])); + acc[0].store(buffer); + for (const auto j : c10::irange(1, Vec::size())) { + buffer[0] = op(buffer[0], buffer[j]); + } + auto dst = (scalar_t*)out_ptr; + *dst = op(*dst, buffer[0]); + } else { + for (const auto j : c10::irange(4)) { + auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t); + acc[j] = vop(acc[j], Vec::loadu(dst)); + acc[j].store(dst); + } + } +} + +template +inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { + for ([[maybe_unused]] const auto j : c10::irange(n)) { + f(); + data[0] += strides[0]; + data[1] += strides[1]; + } +} + +// computes the reduction out = op(out, in) +template +inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { + VEC_LOOP_HEADER(func_t, data) + constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); + int64_t count = n / (4 * Vec::size()); + if (count > 0) { + vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true); + } + char* ptrs[3] = { data[0], data[0], data[1] }; + 
int64_t strides[] = { 0, 0, sizeof(scalar_t) }; + basic_loop(ptrs, strides, count * 4 * Vec::size(), n, op); +} + +// computes the reduction out = op(out, in) +template +inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { + VEC_LOOP_HEADER(func_t, data) + + // reduce down each column of 4 * Vec::size() elements. + constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); + int64_t outer_stride[2] = { vector_stride, vector_stride }; + UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] { + vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false); + }); + + // reduce down the remaining columns + int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) }; + int64_t remaining = size1 % (4 * Vec::size()); + UNARY_OUTER_LOOP(data, step, remaining, [&] { + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t strides[] = { 0, 0, inner_stride }; + basic_loop(ptrs, strides, 0, size0, op); + }); +} + +template +static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) { + // static_assert(std::is_same_v, "data types must match"); + if (index < num_outputs) { + char *out = (char *) iter.data_ptr(index); + *(res_t *) out = result; + } +} + +template +static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) { + AT_ASSERT(num_outputs == 1); + set_result(0, result, iter, num_outputs); +} + +template +inline std::enable_if_t +for_each_in_tuple(const std::tuple& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) { + return i; +} + +template +inline std::enable_if_t +for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { + if (i < (size_t)num_outputs) { + set_result(i, std::get(t), iter, num_outputs); + return for_each_in_tuple(t, iter, num_outputs); + } + return i; +} + +template +static void set_results(const std::tuple& result, const TensorIteratorBase &iter, const int num_outputs) { + AT_ASSERT(num_outputs >= 1); + std::size_t result_size = for_each_in_tuple(result, iter, num_outputs); + AT_ASSERT((size_t)num_outputs == result_size); +} + +template +struct all_same : std::conjunction< + std::is_same... +> {}; + +// data_t is the input/output data type. +// acc_t is a type that contains all the necessary data +// to continue reducing. +// index_t is a one-dimensional index +// +// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy +// the following. +// reduce: (acc_t, data_t, index_t) -> acc_t adds one data point to the accumulated value. +// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one. +// project: acc_t -> out_t finishes the reduction, getting the required output. +// +// Additionally, acc_t must be default-constructible: +// acc_t {} is an identity for combine, +// and project(acc_t {}) is the value of the operation on zero elements. +// +// The point of `combine` is to support parallelization - +// the idea is to one sequence of `reduce` calls per thread of execution, +// and then to combine them at the end with `combine`. +// +// If there is more than one output element, +// our parallelization strategy is to use one thread for each of them, +// which means that `combine` will never be called. +// +// If, on the other hand, there is only one, then we split the input into +// into several pieces, reduce each separately, and then combine them. 
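+// A hypothetical ops_t sketch (illustrative only, not part of the original
+// header) satisfying the reduce/combine/project contract described above, for a
+// sum-of-squares reduction over float data accumulated in double. Note that the
+// implementation below also calls ops.translate_idx(acc, base_idx), which can be
+// a no-op when the accumulator does not track indices:
+//
+//   struct SumSquaresOps {
+//     double reduce(double acc, float data, int64_t /*idx*/) const {
+//       return acc + double(data) * double(data);
+//     }
+//     double combine(double a, double b) const { return a + b; }
+//     float project(double acc) const { return float(acc); }
+//     double translate_idx(double acc, int64_t /*base_idx*/) const { return acc; }
+//   };
+//
+//   // binary_kernel_reduce(iter, SumSquaresOps{}, /*init=*/0.0);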
+ +template +void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { + using rf_t = decltype(&ops_t::reduce); + using cf_t = decltype(&ops_t::combine); + using pf_t = decltype(&ops_t::project); + using r_traits = binary_function_traits; + using c_traits = binary_function_traits; + using p_traits = unary_function_traits; + using acc_t = typename p_traits::arg1_t; + using data_t = typename r_traits::arg2_t; + static_assert( + all_same< + acc_t, + init_t, + typename r_traits::arg1_t, + typename r_traits::result_type, + typename c_traits::arg1_t, + typename c_traits::arg2_t, + typename c_traits::result_type>::value, + "all accumulate types must match"); + static_assert( + std::is_default_constructible_v, + "the accumulate type must be default-constructible" + ); + const int num_outputs = iter.noutputs(); + iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) { + auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t { + int ntensors = sub_iter.ntensors(); + sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) { + AT_ASSERT(ntensors - num_outputs == 1); + char *in = data[ntensors - 1]; + int64_t stride = strides[ntensors - 1]; + for (const auto i : c10::irange(size)) { + acc = ops.reduce(acc, c10::load(in), begin + i); + in += stride; + } + }, {begin, end}); + return ops.translate_idx(acc, sub_iter.view_offsets()[0]); + }; + acc_t total_acc = init; + auto numel = sub_iter.numel(); + if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 || + at::in_parallel_region()) { + total_acc = reduction_body(total_acc, 0, numel); + } else { + int max_threads = at::get_num_threads(); + AT_ASSERT(max_threads > 0); + static_assert( + !std::is_same_v, + "Concurrently modifying different references into std::vector is UB." 
+ ); + std::vector buffer((unsigned)max_threads, init); + at::parallel_for(0, numel, internal::GRAIN_SIZE, + [&](int64_t begin, int64_t end) { + auto& acc = buffer[at::get_thread_num()]; + acc = reduction_body(acc, begin, end); + } + ); + for (const auto i : c10::irange(max_threads)) { + total_acc = ops.combine(total_acc, buffer[i]); + } + } + set_results(ops.project(total_acc), sub_iter, num_outputs); + }); +} + +template +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { + using traits = binary_function_traits; + static_assert( + all_same< + typename traits::result_type, + typename traits::arg1_t, + typename traits::arg2_t>::value, + "all types must match"); + + iter.output_base().fill_(ident); + iter.parallel_reduce([&](char** data, const int64_t* strides, int64_t size0, int64_t size1) { + int64_t outer_strides[] = { strides[2], strides[3] }; + if (is_contiguous_reduction(strides)) { + // input is contiguous in dim 0, output is reduced in dim 0 + UNARY_OUTER_LOOP(data, outer_strides, size1, [&] { + vectorized_inner_reduction(data, size0, op, vop); + }); + } else if (is_outer_reduction(strides)) { + // input and output are contiguous in dim 1 + int64_t inner_stride = strides[1]; // stride of input in dim 0 + vectorized_outer_reduction(data, inner_stride, size0, size1, op, vop); + } else { + UNARY_OUTER_LOOP(data, outer_strides, size1, [&] { + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t inner_strides[3] = { strides[0], strides[0], strides[1] }; + basic_loop(ptrs, inner_strides, 0, size0, op); + }); + } + }); +} + +// when reduction is on most inner dimension (dim 0 in TensorIterator) +// and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim` +// can be used. +inline bool is_reduce_lastdim(TensorIteratorBase& iter) { + return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0) + && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1); +} + +template +void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) { + auto shape = iter.shape(); + int64_t dim_size = shape[0]; + int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size); + TensorIterator sub_iter(iter); + // create sub iterator to parallel on all non-reduce-dims + sub_iter.narrow(0, 0, 1); + auto loop = [&](char** data, const int64_t* strides, int64_t size) { + char* out = data[0]; + char* in = data[1]; + for (int64_t i = 0; i < size; ++i) { + reduce_op(out, in, dim_size); + out += strides[0]; + in += strides[1]; + } + }; + sub_iter.for_each(loop, grain_size); +} + +}} // namespace at::native:: diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..5d1901c98b34fe8b87e2cc67fae31ce53db9abae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h @@ -0,0 +1,238 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +using namespace vec; + +#define AT_DISPATCH_REDUCTION_TYPES(op, ...) 
\ + [&] { \ + switch (op) { \ + case ReductionType::SUM: { \ + static constexpr auto reduce = ReductionType::SUM; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MEAN: { \ + static constexpr auto reduce = ReductionType::MEAN; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MIN: { \ + static constexpr auto reduce = ReductionType::MIN; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MAX: { \ + static constexpr auto reduce = ReductionType::MAX; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::PROD: { \ + static constexpr auto reduce = ReductionType::PROD; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + +template +inline vec_scalar_t init_value() { + using acc_t = vec_scalar_t; + acc_t val; + if (reduce == ReductionType::SUM || + reduce == ReductionType::MEAN) { + val = static_cast(0); + } else if (reduce == ReductionType::PROD) { + val = static_cast(1); + } else if (reduce == ReductionType::MAX) { + val = -std::numeric_limits::infinity(); + } else { + TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); + val = std::numeric_limits::infinity(); + } + return val; +} + +template +inline vec_scalar_t init_value(const std::optional& initial) { + using acc_t = vec_scalar_t; + if (initial.has_value()) { + return initial.value().to(); + } else { + return init_value(); + } +} + +template +inline void init(scalar_t* out, int64_t size, const vec_scalar_t& val) { + using Vec = Vectorized>; + map( + [val](Vec x) { return Vec(val); }, + out, + out, + size); +} + +template +inline void init(scalar_t* out, int64_t size, const std::optional& initial) { + using acc_t = vec_scalar_t; + acc_t val = init_value(initial); + init(out, size, val); +} + +// overload with `include_self`, used by scatter_reduce +template +inline void init(scalar_t* out, int64_t size, bool include_self = false) { + using acc_t = vec_scalar_t; + if (!include_self) { + acc_t val = init_value(); + init(out, size, val); + } +} + +template +inline void _init(scalar_t* self_ptr, at::opmath_type* buffer_ptr, int64_t size, bool include_self) { + if (!include_self) { + init, reduce>(buffer_ptr, size, include_self); + } else { + vec::convert(self_ptr, buffer_ptr, size); + } +} + +template +inline std::enable_if_t, scalar_t> +_max(const scalar_t& x, const scalar_t& y) { + return at::_isnan(y) ? y : std::max(x, y); +} + +template +inline Vectorized _max(const Vectorized& x, const Vectorized& y) { + // vec::maximum propagates NaN + return vec::maximum(x, y); +} + +template +inline std::enable_if_t, Vec2> +_max(const vec_t& x, const vec_t& y) { + // vec::maximum propagates NaN + return maximum(x, y); +} + +template +inline std::enable_if_t, scalar_t> +_min(const scalar_t& x, const scalar_t& y) { + return at::_isnan(y) ? 
y : std::min(x, y); +} + +template +inline Vectorized _min(const Vectorized& x, const Vectorized& y) { + // vec::minimum propagates NaN + return vec::minimum(x, y); +} + +template +inline std::enable_if_t, Vec2> +_min(const vec_t& x, const vec_t& y) { + // vec::minimum propagates NaN + return minimum(x, y); +} + +template , int> = 0> +inline void map_acc( + const Op& vec_fun, + accumut* output_data, + const accumut* input_data, + const scalar_t* input_data2, + int64_t size) { + using Vec = vec::Vectorized; + using aVec = vec::Vectorized; + int64_t d = 0; + constexpr int64_t kVecSize = Vec::size(); + constexpr int64_t kaVecSize = aVec::size(); + for (d = 0; d < size - (size % kVecSize); d += kVecSize) { + Vec data2_vec = Vec::loadu(input_data2 + d); + auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec); + aVec input_vec0 = aVec::loadu(input_data + d); + aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize); + vec_fun(input_vec0, data2_avec0).store(output_data + d); + vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize); + } + if (size - d > 0) { + int64_t tail_size = size - d; + Vec data2_vec = Vec::loadu(input_data2 + d, tail_size); + auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec); + if (tail_size > kaVecSize) { + aVec input_vec0 = aVec::loadu(input_data + d); + aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize); + vec_fun(input_vec0, data2_avec0).store(output_data + d); + vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize); + } else { + aVec input_vec0 = aVec::loadu(input_data + d, tail_size); + vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size); + } + } +} + +// for Max and Min, propagate NaN: +template +inline T update(const T& x, const T& y) { + if (reduce == ReductionType::SUM || + reduce == ReductionType::MEAN) { + return x + y; + } else if (reduce == ReductionType::PROD) { + return x * y; + } else if (reduce == ReductionType::MAX) { + return _max(x, y); + } else { + TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); + return _min(x, y); + } +} + +template +inline void update(scalar_t* out, const scalar_t* data, int64_t K) { + using Vec = vec::Vectorized>; + map2( + [](Vec x, Vec y) { return update(x, y); }, + out, + out, + data, + K); +} + +template , int> = 0> +inline void update(at::opmath_type* out, const scalar_t* data, int64_t K) { + using opmath_t = at::opmath_type; + using Vec = vec::Vectorized; + map_acc( + [](Vec x, Vec y) { return update(x, y); }, + out, + out, + data, + K); +} + +template +inline void write(scalar_t* out, int64_t count, int64_t K) { + using Vec = vec::Vectorized>; + if (reduce == ReductionType::MEAN) { + if (count > 0) { + vec::map( + [count](Vec x) { return x / Vec(count); }, + out, + out, + K); + } + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..006ff75f72716d7f57865d036d3d7a02a2298b28 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { +#if !defined(C10_MOBILE) +using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int); 
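+// Parameter meaning is an assumption here (the original header leaves them
+// unnamed): the signature appears to follow the BLAS gemv convention,
+// presumably (m, n, alpha, a, lda, x, incx, beta, y, incy).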
+DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub) + +using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int); +DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub) + +using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t); +DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub) + +using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t); +DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub) + +inline namespace CPU_CAPABILITY { +float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len); +float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len); +} // inline namespace CPU_CAPABILITY +#endif // !defined(C10_MOBILE) +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..29479c8211d20d67c9c2f7b18989a9f1418695f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace at::native { + +using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&); + +DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub) + +} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..203d0b8cf525639984ab4f09d742efa6f52866c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h @@ -0,0 +1,146 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native::detail { + +struct InputMeta { + void* data_ptr; + int64_t inner_size; + + InputMeta(const Tensor& t, int64_t dim, int64_t inner) + : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {} +}; + +// This kernel is used by two TensorList types: +// 1. stack_serial_kernel uses at::ArrayRef +// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with +// ProcessedNodeInputWrapper. +// When making changes, make sure that they are compatible with both types! 
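+// Descriptive aside (not part of the original header): for each of the `outer`
+// slices of `result`, the loop below copies one contiguous `inner_size` slab
+// from every input in turn, so along `dim` the output interleaves
+// input0-slab, input1-slab, ..., repeated `outer` times; slabs of at least
+// Vec::size() elements are copied with the vectorized map, shorter ones with a
+// scalar loop.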
+template +void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + dim >= 0 && dim <= result.dim(), + "dim out of range in stack_serial_kernel_impl"); + int64_t outer = + result.numel() / (result.sizes()[dim] * result.strides()[dim]); + scalar_t* result_data = result.data_ptr(); + int64_t ninputs = tensors.size(); + std::vector inputs; + inputs.reserve(ninputs); + for (const auto& tensor : tensors) { + inputs.emplace_back(tensor, dim, tensor.strides()[dim]); + } + + using Vec = vec::Vectorized; + scalar_t* result_ptr = result_data; + for (const auto i : c10::irange(outer)) { + for (const auto j : c10::irange(ninputs)) { + int64_t local_inner = inputs[j].inner_size; + scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner; + + if (local_inner < Vec::size()) { + for (const auto k : c10::irange(local_inner)) { + result_ptr[k] = input_ptr[k]; + } + } else { + vec::map( + [](Vec x) { return x; }, result_ptr, input_ptr, local_inner); + } + result_ptr += local_inner; + } + } +} + +// Checks to see whether native stack can be invoked under these conditions: +// - result and input tensors are contiguous +// - only one thread is used +// - no type promotion has to occur +// - tensors dtype is Double or Float +template +bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) { + TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors"); + const Tensor& first_tensor = tensors[0]; + // stack dimension should be in range [0,firstTensor.dim()) + // dim == firstTensor.dim() is a valid input, but it is handled by default code path + // that uses unsqueeze + if (dim >= first_tensor.dim()) return false; + // Native stack doesn't apply any tensor is skipped. + if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false; + // there should be no type promotion + if (result.dtype() != first_tensor.dtype()) return false; + + auto first_tensor_mem_format = first_tensor.suggest_memory_format(); + ScalarType dtype = first_tensor.scalar_type(); + + if (!result.is_contiguous(first_tensor_mem_format)) { + return false; + } + + // fast path only works for Double and Float + if (dtype != ScalarType::Double && dtype != ScalarType::Float) { + return false; + } + + // check remainder of inputs +#ifndef STRIP_ERROR_MESSAGES + auto const &first_tensor_shape = first_tensor.sizes(); +#endif + for (const auto i : c10::irange(1, tensors.size())) { + auto const &tensor = tensors[i]; + TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(), + "stack expects each tensor to be equal size, but got ", first_tensor_shape, + " at entry 0 and ", tensor.sizes(), " at entry ", i); + + // every tensor must be contiguous + // tensor sizes and strides must be the same + // there should be no type promotion + if (!tensor.is_contiguous(first_tensor_mem_format) || + tensor.strides() != first_tensor.strides() || + tensor.dtype() != dtype) { + return false; + } + } + + // fast native stack should only be used when it is not worth using multiple threads + // or there is only one thread. Note that we aren't checking result.numel() here because + // it may not have been resized and we want to defer that cost till later. 
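+  // Worked example (illustrative, assuming GRAIN_SIZE == 32768 as noted in
+  // LogSoftmaxKernelImpl.h above): stacking 8 tensors of 1000 floats gives
+  // numel_in_stack == 8000 < 32768, so the serial fast path is taken; much
+  // larger stacks return false here unless only one thread is available.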
+ int64_t numel_in_stack = first_tensor.numel() * tensors.size(); + return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1; +} + +template +struct CanUseNativeSerialStack; + +template +struct CanUseNativeSerialStack { + static bool call(Tensor& result, TensorListType tensors, int64_t dim) { + // Inputs cannot alias the output tensor + for (const auto i : c10::irange(tensors.size())) { + auto lap = at::get_overlap_status(result, tensors[i]); + TORCH_CHECK(lap != at::MemOverlapStatus::Partial && + lap != at::MemOverlapStatus::Full, 0, + "unsupported operation: the input tensors cannot refer to any of the " + "output memory locations. Found overlap in input tensor ", i); + } + + return can_use_native_serial_stack_impl(result, tensors, dim); + } +}; + +template +struct CanUseNativeSerialStack { + static bool call(Tensor& result, TensorListType tensors, int64_t dim) { + return can_use_native_serial_stack_impl(result, tensors, dim); + } +}; + +} // namespace at::native::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..fac10bad489fe3358eb2cdb79d5e6edbc20f14d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using forward_fn = void (*)(const Tensor&, const Tensor&); +using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&); + +DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel) +DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel) +DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel) +DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel) + +using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t); +using backward_fn_with_dim = + void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t); + +DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel) +DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel) +DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel) +DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel) +} +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..dba26240e57fb552aa051278dbdc7d07914a9aea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); + +DECLARE_DISPATCH(spmm_reduce_fn, 
spmm_reduce_stub) +DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub) +DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub) +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub) +DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub) +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub) + +} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/StackKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/StackKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3bd5f585f9d23896bd26259b4370a486db288f37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/StackKernel.h @@ -0,0 +1,12 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +#pragma once + +#include +#include + +namespace at::native { + +using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t); +DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h new file mode 100644 index 0000000000000000000000000000000000000000..fe9c323e4266962e41eac7c98c8509e26c59b305 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -0,0 +1,1376 @@ +/* +The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh + +Pillow is the friendly PIL fork. It is + + Copyright © 2010-2022 by Alex Clark and contributors + +Like PIL, Pillow is licensed under the open source HPND License +*/ + +// This code is heavily inspired from PILLOW-SIMD's implementation: +// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c + +#pragma once +#ifdef CPU_CAPABILITY_AVX2 +// TODO: This file only supports AVX2. We could split the AVX kernels into +// smaller logical blocks in order to port them into the Vec.h logic. This would +// allow to support other vectorization architectures and perhaps also support +// the non-vectorized fallback (we'd need to make sure it's not slower than the +// current fallback). + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + + +namespace { + +static inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) { + int32_t v; + if (i32_aligned) { + v = *(const int32_t*)ptr; + } else { + std::memcpy(&v, ptr, 4); + } + return _mm_cvtsi32_si128(v); +} + +static inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) { + return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned)); +} + +static inline void _write_endline_rgb_as_uint32( + uint8_t* C10_RESTRICT output, + uint32_t data +) { + // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...) + // Here we explicitly set X as R1 + uint8_t* data_ptr = reinterpret_cast(&data); + data_ptr[3] = output[3]; + std::memcpy(output, data_ptr, 4); +} + +at::Tensor unpack_rgb(const at::Tensor& packed_tensor) { + // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into + // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded + // into as 32 bits. This generalizes to num_channels <= 4 and also works for + // non-channels_last tensors. 
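+  // Illustrative example (not part of the original header): for a 3-channel
+  // channels_last input, packed bytes R0 G0 B0 R1 G1 B1 ... become
+  // R0 G0 B0 0 R1 G1 B1 0 ..., i.e. each pixel is widened to a 4-byte RGBA slot
+  // with A hard-coded to 0.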
+ + const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr(); + auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2); + auto num_channels = packed_tensor.size(0); + + constexpr int rgba_size = 4; + auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte)); + uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr(); + + auto stride_i = packed_tensor.stride(2); + auto stride_j = packed_tensor.stride(0); + + for (const auto i : c10::irange(num_pixels)) { + for (const auto j : c10::irange(rgba_size)) { + unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0; + } + } + return unpacked_tensor; +} + +void pack_rgb( + const at::Tensor& unpacked_tensor, // IN + const at::Tensor& packed_tensor // OUT +) { + // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout. + + uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr(); + uint8_t* packed = (uint8_t*)packed_tensor.data_ptr(); + auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2); + auto num_channels = packed_tensor.size(0); + + auto unpacked_increment = unpacked_tensor.size(0); + auto packed_increment = packed_tensor.stride(2); + auto packed_stride = packed_tensor.stride(0); + + TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4); + + for ([[maybe_unused]] const auto i : c10::irange(num_pixels)) { + for (const auto j : c10::irange(num_channels)) { + packed[j * packed_stride] = unpacked[j]; + } + unpacked += unpacked_increment; + packed += packed_increment; + } +} + +void ImagingResampleHorizontalConvolution8u4x( + uint8_t* C10_RESTRICT lineOut0, + uint8_t* C10_RESTRICT lineOut1, + uint8_t* C10_RESTRICT lineOut2, + uint8_t* C10_RESTRICT lineOut3, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn0, + const uint8_t* C10_RESTRICT lineIn1, + const uint8_t* C10_RESTRICT lineIn2, + const uint8_t* C10_RESTRICT lineIn3, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line); + +void ImagingResampleHorizontalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line); + +void ImagingResampleVerticalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + const uint8_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision, + int64_t num_channels); + +template +void ImagingResampleHorizontal( + const at::Tensor & unpacked_output, + const at::Tensor & unpacked_input, + int ksize, + const std::vector& horiz_indices_weights, + unsigned int horiz_weights_precision) { + + // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs. + + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... 
+ r[yoffset + xmin[i] + K-1] * w[i, K-1] + // oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1] + // oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1] + // + + // TODO: we may want to merge that into the fallback code (currently called + // basic_loop_aa_horizontal) + // Although this may not be needed if / when we port all this code to use + // Vec.h since this would potentially give us another fall-back implem + + const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr()); + + auto xout = unpacked_output.size(2); + auto yout = unpacked_output.size(1); + auto xin = unpacked_input.size(2); + TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0)); + + const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr(); + const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr(); + + uint8_t* unpacked_output_p = unpacked_output.data_ptr(); + const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr(); + + int64_t yy = 0; + auto xout_stride = xout * num_channels; + auto xin_stride = xin * num_channels; + for (; yy < yout - 3; yy += 4) { + ImagingResampleHorizontalConvolution8u4x( + unpacked_output_p + yy * xout_stride, + unpacked_output_p + (yy + 1) * xout_stride, + unpacked_output_p + (yy + 2) * xout_stride, + unpacked_output_p + (yy + 3) * xout_stride, + xout, + unpacked_input_p + yy * xin_stride, + unpacked_input_p + (yy + 1) * xin_stride, + unpacked_input_p + (yy + 2) * xin_stride, + unpacked_input_p + (yy + 3) * xin_stride, + xin, + idx_ptr_xmin, + idx_ptr_size, + kk, + ksize, + horiz_weights_precision, + num_channels, + yy + 3 == yout - 1); + } + for (; yy < yout; yy++) { + ImagingResampleHorizontalConvolution8u( + unpacked_output_p + yy * xout_stride, + xout, + unpacked_input_p + yy * xin_stride, + xin, + idx_ptr_xmin, + idx_ptr_size, + kk, + ksize, + horiz_weights_precision, + num_channels, + yy == yout - 1); + } +} + +void ImagingResampleVertical( + const at::Tensor & unpacked_output, + const at::Tensor & unpacked_input, + int ksize, + const std::vector& vert_indices_weights, + unsigned int vert_weights_precision) { + + // Interpolation vertical pass: we compute y-axis interpolation outputs. + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... 
+ b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + + // TODO: we may want to merge that into the fallback code (currently called + // basic_loop_aa_vertical) + // Although this may not be needed if / when we port all this code to use + // Vec.h since this would potentially give us another fall-back implem + const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr()); + + const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr(); + const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr(); + + uint8_t* unpacked_output_p = unpacked_output.data_ptr(); + const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr(); + + auto xout = unpacked_output.size(2); + auto yout = unpacked_output.size(1); + const auto num_channels = unpacked_input.size(0); + TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0)); + + auto xout_stride = xout * num_channels; + for (const auto yy : c10::irange(yout)) { + const auto* k = &kk[yy * ksize]; + auto ids_min = idx_ptr_xmin[yy]; + auto ids_size = idx_ptr_size[yy]; + ImagingResampleVerticalConvolution8u( + unpacked_output_p + yy * xout_stride, + unpacked_input_p, + xout, + ids_min, + ids_size, + k, + vert_weights_precision, + num_channels); + } +} + +// This is the only public entry point in this file. It supports bilinear or bicubic +// mode for uint8 dtype when C <= 4, with or without antialias. The +// implem is based on PIL-SIMD. +// Its equivalent implementation (fallback) for when AVX isn't supported or when +// C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of +// future improvement that can be done: look for the TODOs in this file. +// For details on how the weights are computed and how the multiplications are +// run on int (instead of float weights), see +// [ Weights computation for uint8_t and multiplication trick ] +// For details on how the AVX kernels are implemented, see +// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5 +// See also [ Support for antialias=False as a subcase of antialias=True ] to +// learn more about how the antialias=False case is computed. The same holds +// here: all these kernels are general enough to handle an arbitrary number of +// weights, but when aa=False they could be optimized further. +template +void upsample_avx_bilinear_bicubic_uint8( + const at::Tensor& input_, + const at::Tensor& output, + bool align_corners, + const scale_type& scales, + bool antialias) { + auto batch_size = input_.size(0); + auto num_channels = input_.size(1); + auto xin = input_.size(3); + auto yin = input_.size(2); + auto xout = output.size(3); + auto yout = output.size(2); + + if (xin == xout && yin == yout) { + output.copy_(input_); + return; + } + + at::Tensor input = input_; + if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) { + // If input is not contiguous with memory format channels first or channels last, + // we explicitly convert the input to contiguous channels last memory format. + // This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last, + // Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may + // have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input + // directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex. 
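+    // For example (illustrative values only): a non-contiguous NCHW tensor of shape
+    // [1, 3, 32, 32] produced by slicing may carry strides such as [6144, 2048, 64, 2];
+    // after .contiguous(at::MemoryFormat::ChannelsLast) its strides become
+    // [3072, 1, 96, 3], i.e. the channel values of each pixel sit next to each other
+    // in memory, which is the packed per-pixel layout the horizontal and vertical
+    // passes below operate on.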
+ input = input.contiguous(at::MemoryFormat::ChannelsLast); + } + + auto need_horizontal = xout != xin; + auto need_vertical = yout != yin; + + int ksize_horiz, ksize_vert; + std::vector horiz_indices_weights, vert_indices_weights; + unsigned int horiz_weights_precision, vert_weights_precision; + + bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast); + bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast); + + if (need_horizontal) { + int interp_dim = 3; + auto stride = (skip_unpacking) ? num_channels : 4; + std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) = + F::compute_index_ranges_int16_weights( + /*input_size=*/xin, + /*output_size=*/xout, + /*stride=*/stride, + /*ndims=*/4, + /*reshape_dim=*/interp_dim, + /*align_corners=*/align_corners, + /*opt_scale=*/scales[interp_dim - 2], + /*antialias=*/antialias, + /*align_i32=*/true); + } + + if (need_vertical) { + int interp_dim = 2; + auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout; + std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) = + F::compute_index_ranges_int16_weights( + /*input_size=*/yin, + /*output_size=*/yout, + /*stride=*/stride, + /*ndims=*/4, + /*reshape_dim=*/interp_dim, + /*align_corners=*/align_corners, + /*opt_scale=*/scales[interp_dim - 2], + /*antialias=*/antialias, + /*align_i32=*/true); + } + + at::Tensor buffer_horiz, buffer_vert; + // Minor optimization: we can avoid allocating an extra buffer if we're performing + // horizontal-only or vertical-only interpolation, and if the tensor doesn't + // need repacking + if (need_horizontal && (need_vertical || !skip_packing)) { + auto c = (skip_unpacking) ? num_channels : 4; + buffer_horiz = at::empty({c, yin, xout}, input.options()); + } + if (need_vertical && !skip_packing) { + auto c = (skip_unpacking) ? num_channels : 4; + buffer_vert = at::empty({c, yout, xout}, input.options()); + } + + for (const auto i : c10::irange(batch_size)) { + + at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]); + at::Tensor unpacked_output; + + if (need_horizontal) { + at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i]; + + if (skip_unpacking && num_channels == 3) { + ImagingResampleHorizontal<3>( + unpacked_output_temp, + unpacked_input, + ksize_horiz, + horiz_indices_weights, + horiz_weights_precision); + } else { + ImagingResampleHorizontal<4>( + unpacked_output_temp, + unpacked_input, + ksize_horiz, + horiz_indices_weights, + horiz_weights_precision); + } + unpacked_output = unpacked_input = unpacked_output_temp; + } + if (need_vertical) { + unpacked_output = (skip_packing) ? 
output[i] : buffer_vert; + + ImagingResampleVertical( + unpacked_output, + unpacked_input, + ksize_vert, + vert_indices_weights, + vert_weights_precision + ); + } + + TORCH_INTERNAL_ASSERT(unpacked_output.defined()); + + if (!skip_packing) { + pack_rgb(unpacked_output, output[i]); + } + } +} + +void ImagingResampleHorizontalConvolution8u4x( + uint8_t* C10_RESTRICT lineOut0, + uint8_t* C10_RESTRICT lineOut1, + uint8_t* C10_RESTRICT lineOut2, + uint8_t* C10_RESTRICT lineOut3, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn0, + const uint8_t* C10_RESTRICT lineIn1, + const uint8_t* C10_RESTRICT lineIn2, + const uint8_t* C10_RESTRICT lineIn3, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line) { + + // Interpolation horizontal pass processing together 4 vertical lines. + // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA + // we can encode 4 values as a single uint32 value. + // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1. + // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values + // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1"). + + // Define shuffling masks (low/high) for num_channels 4 and 3 + // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA: + // [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] -> + // [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0] + // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:: + // [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] -> + // [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0] + + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); + const auto mask_low_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_high_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6); + + const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4; + const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4; + + const auto stride = num_channels * sizeof(uint8_t); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + // out_xsize = output width, out_x = output x index + // ids_min is the input offset index corresponding to out_x + // ids_size is the interpolation size for out_x + + // Let's precompute ids_size limits for block 4 and block 2. + // + // In block 4 (4 means we process 4 weight values together), we read input data + // with _mm_loadu_si128, i.e. 
16 bytes, per one line: + // lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 16.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft + // RGBA: b4_delta = b4_delta_soft = 3 + // RGB : b4_delta = 5 + // RGB : b4_delta_soft = 4 + const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4); + + // In block 2 (2 means we process 2 weights values together), we read input data + // with _mm_loadl_epi64, i.e. 8 bytes, per one line: + // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 8.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft + // RGBA: b2_delta = b2_delta_soft = 1 + // RGB : b2_delta = 2 + // RGB : b2_delta_soft = 1 + const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1); + + const auto max_out_x_strided = out_xsize * stride; + const auto max_in_x_strided = in_xsize * stride; + + const auto zero = _mm256_setzero_si256(); + const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1)); + + for (const auto out_x : c10::irange(out_xsize)) { + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; + + auto sss0 = initial; + auto sss1 = initial; + + const auto * lineIn0_min = lineIn0 + ids_min; + const auto * lineIn1_min = lineIn1 + ids_min; + const auto * lineIn2_min = lineIn2 + ids_min; + const auto * lineIn3_min = lineIn3 + ids_min; + + // block 4 + for (; i < ids_size - b4_delta; i += 4) { + // Load 4 values from weight vector + // mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + // mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...] 
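+      // Note: two adjacent int16 weights are reinterpreted as a single int32 and
+      // broadcast to every 32-bit lane, so each _mm256_madd_epi16 below accumulates
+      // w[i] * c0 + w[i + 1] * c1 per channel in one instruction.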
+ const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]); + const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]); + + // RGBA: Load 8 pixels (4 per line) from input lines 0 and 1: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3 + // ] + // RGB: Load 10 pixels (5 per line) + // source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5 + // ] + auto source = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1); + + // Apply mask_low: + // RGBA: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + // RGB: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0)); + + // Apply mask_high: + // RGBA: + // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0] + // RGB: + // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0] + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1)); + + // Same as above to next lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1); + auto pix3 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0)); + auto pix4 = _mm256_shuffle_epi8(source2, mask_high); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1)); + } + + // block 2 + for (; i < ids_size - b2_delta; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] 
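+      // Note: the 8-byte loads below cover exactly the two pixels weighted by this
+      // pair of coefficients (2 RGBA pixels, or 2 RGB pixels plus two spare bytes),
+      // so only mask_low is needed to expand them to epi16.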
+ const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); + + // Load 4 pixels (2 per line) from input lines 0 and 1: + // RGBA: source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0 + // R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))), + _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1); + // Apply mask_low: + // RGBA: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + // RGB: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0] + auto pix1 = _mm256_shuffle_epi8(source1, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // Same as above for lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))), + _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1); + auto pix2 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + // block 1 + const auto i32_aligned = num_channels == 4; + for (; i < ids_size - 1; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] + const auto mmk = _mm256_set1_epi32(k[i]); + + // Load 2 pixels (one per line) from input lines 0 and 1: + // RGBA: pix1 = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0 + // R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0 + // ] + // RGB: pix1 = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0 + // R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0 + // ] + auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // Same as above for lines 2 and 3 + auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + if (i == ids_size - 1) { + // last element + auto mmk = _mm256_set1_epi32(k[i]); + // For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes + // lines 0, 1 and 2 wont go out of allocated memory bounds + auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1); + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk)); + + auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned); + __m128i p1; + if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) { + uint8_t input[4]; + std::memcpy(input, lineIn3_min + stride * i, 3); + p1 = mm_cvtepu8_epi32(input, true); + } else { + p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned); + } + auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss0 = _mm256_srai_epi32(sss0, coefs_precision); + sss1 = 
_mm256_srai_epi32(sss1, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) + sss0 = _mm256_packs_epi32(sss0, zero); + sss1 = _mm256_packs_epi32(sss1, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) + sss0 = _mm256_packus_epi16(sss0, zero); + sss1 = _mm256_packus_epi16(sss1, zero); + + // Write the output into single uint32 + // (a b c d) -> x_uint32 + auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0)); + auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1)); + auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1)); + auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1)); + + const auto out_x_strided = stride * out_x; + + if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) { + // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write + // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1). + // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct + // value which was previously computed by another line. In other words, it means that we can not overwrite + // it by simply writing 4 bytes from the register to the output. We'll do the following: + // v----------| + // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...] + // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1) + // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1) + // Output = [... R G B | R1 G1 B1 R2 ...] + + _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0); + _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1); + _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2); + + if (C10_UNLIKELY(is_last_line)) { + // When we handle the last line, we can not access the next 4 bytes + // as they are out of memory bounds. + std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels); + } else { + _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3); + } + } else if (num_channels == 3) { + // Memcpy 4-bytes is faster than 3-bytes and here + // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value + // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...) + std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4); + std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4); + std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4); + std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4); + } else { + // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned + *(uint32_t *)(lineOut0 + out_x_strided) = o0; + *(uint32_t *)(lineOut1 + out_x_strided) = o1; + *(uint32_t *)(lineOut2 + out_x_strided) = o2; + *(uint32_t *)(lineOut3 + out_x_strided) = o3; + } + } +} + +void ImagingResampleHorizontalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line) { + + // Interpolation horizontal pass processing only one vertical line. + // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA + // we can encode 4 values as a single uint32 value. 
+ // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1 + // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in + // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1"). + + // Define various shuffling masks + const auto kmask_low = _mm256_set_epi8( + 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + const auto kmask_high = _mm256_set_epi8( + 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4); + const auto kmask_hl = _mm256_set_epi8( + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); + const auto mask_low_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_high_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6); + const auto mask_hl_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_hl_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + + const auto mask_low128_c3 = _mm_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_low128_c4 = _mm_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + + const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4; + const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4; + const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4; + const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4; + + // out_xsize = output width, out_x = output x index + // ids_min is the input offset index corresponding to out_x + // ids_size is the interpolation size for out_x + + const auto stride = num_channels * sizeof(uint8_t); + const auto zero = _mm_setzero_si128(); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + // Let's precompute ids_size limits for block 8, block 4 and block 2 + // + // In block 8 (8 means we process 8 weight values together), we read at + // most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB) + // lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min) + // --> i <= ids_size - 32.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft + // RGBA: b8_delta = b8_delta_soft = 7 + // RGB : b8_delta = 10 + // RGB : b8_delta_soft = 9 + const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 
10 : 9); + + // In block 4 (4 means we process 4 weight values together), we read + // 16 bytes of input data. + // lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 16.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft + // RGBA: b4_delta = b4_delta_soft = 3 + // RGB : b4_delta = 5 + // RGB : b4_delta_soft = 4 + const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4); + + // In block 2 (2 means we process 2 weight values together), we read + // 8 bytes of input data. + // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 8.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft + // RGBA: b2_delta = b2_delta_soft = 1 + // RGB : b2_delta = 2 + // RGB : b2_delta_soft = 1 + const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1); + + const auto max_out_x_strided = out_xsize * stride; + const auto max_in_x_strided = in_xsize * stride; + + for (const auto out_x : c10::irange(out_xsize)) { + __m128i sss; + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; + + const auto * lineIn_min = lineIn + ids_min; + + if (ids_size < 8) { + sss = _mm_set1_epi32(1 << (coefs_precision - 1)); + } else { + // Lower part will be added to higher, use only half of the error + auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2)); + + // block 8 + for (; i < ids_size - b8_delta; i += 8) { + // Load 8 values from weight vector + auto tmp = _mm_loadu_si128((__m128i*)&k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // RGBA: Load 8 pixels from input: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + // RGB: Load 10 pixels from input (however we can process only 8 pixels): + // source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9 + // ] + auto source = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1); + + // Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA + // RGBA: pix1 = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0 + // ] + // RGB: pix1 = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 + // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0 + // ] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // mmk1 = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ... + // wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ... 
+ // ] + auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w4 * C4 + w5 * C5 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1)); + + // Same as above for higher part of each lane + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high); + // Compute output value as + // C += w2 * C2 + w3 * C3 + // C += w6 * C6 + w7 * C7 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2)); + } + + // block 4 + for (; i < ids_size - b4_delta; i += 4) { + // Load 4 values from weight vector + auto tmp = _mm_loadl_epi64((__m128i *) &k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // Load pixels from input line + tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i)); + // RGBA: source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // ] + // RGB: source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // ] + auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + // RGBA: pix = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 + // ] + // RGB: pix = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 + // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 + // ] + auto pix = _mm256_shuffle_epi8(source, mask_hl); + // mmk = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ... + // wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ... + // ] + auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w2 * C2 + w3 * C3 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk)); + } + + // Sum results between the lanes + sss = _mm_add_epi32( + _mm256_extracti128_si256(sss256, 0), + _mm256_extracti128_si256(sss256, 1)); + } + + // block 2 + for (; i < ids_size - b2_delta; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + // Load pixels from input line + // RGBA: source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + // RGB: source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0 + // ] + auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i)); + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + auto pix = _mm_shuffle_epi8(source, mask_low128); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // block 1 + const auto i32_aligned = num_channels == 4; + for (; i < ids_size - 1; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] 
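+      // Note: k[i] is an int16 value; _mm_set1_epi32(k[i]) sign-extends it, but the
+      // upper 16 bits of each lane only ever multiply the zeroed high half of the
+      // pixel produced by mm_cvtepu8_epi32, so they do not affect the
+      // _mm_madd_epi16 result.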
+ auto mmk = _mm_set1_epi32(k[i]); + // Load one pixel from input line + // RGBA: pix = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0 + // ] + // RGB: pix = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0 + // ] + auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + if (i == ids_size - 1) { + // last element + auto mmk = _mm_set1_epi32(k[i]); + __m128i pix; + auto p = lineIn_min + stride * i; + if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) { + uint8_t input[4]; + std::memcpy(input, p, 3); + pix = mm_cvtepu8_epi32(input, true); + } else { + pix = mm_cvtepu8_epi32(p, i32_aligned); + } + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss = _mm_srai_epi32(sss, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) + sss = _mm_packs_epi32(sss, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) + sss = _mm_packus_epi16(sss, zero); + // Write the output into single uint32 + // (a b c d) -> x_uint32 + auto o = _mm_cvtsi128_si32(sss); + const auto out_x_strided = stride * out_x; + if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) { + if (C10_UNLIKELY(is_last_line)) { + // When we handle the last line, we can not access the next 4 bytes + // as they are out of memory bounds. + std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3); + } else { + // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write + // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1). + // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct + // value which was previously computed by another line. In other words, it means that we can not overwrite + // it by simply writing 4 bytes from the register to the output. We'll do the following: + // v----------| + // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...] + // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1) + // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1) + // Output = [... R G B | R1 G1 B1 R2 ...] + _write_endline_rgb_as_uint32(lineOut + out_x_strided, o); + } + } else if (num_channels == 3) { + // Memcpy 4-bytes is faster than 3-bytes and here + // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value + // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...) + std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4); + } else { + // num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned + *(uint32_t *)(lineOut + out_x_strided) = o; + } + } +} + +void ImagingResampleVerticalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + const uint8_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision, + int64_t num_channels) { + + // Interpolation vertical pass processing one line. + // - We process x-axis data with blocks of 8, 2 and 1 + // - We split the size of weight vector for a given output index as a sum: K = n * 2 + m. 
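+  // For example, with ids_size = 5 the main loops below run over two pairs of taps
+  // (i = 0 and i = 2) and then over the one remaining tap (i = 4).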
+ + // xsize = output width, also equals to input width + // ids_size = interpolation size + // ids_min = input y start index + const auto stride = num_channels * sizeof(uint8_t); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + const int64_t data_size = xsize * stride; + const int64_t data_stride = stride; + constexpr auto vec_size = 256 / 8; + + const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1)); + const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1)); + const auto zero = _mm_setzero_si128(); + const auto zero_256 = _mm256_setzero_si256(); + + int64_t j = 0; + // block 8 + const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride; + for (; j < data_size - vec_size; j += b8_usable_vec_stride) { + auto sss0 = initial_256; + auto sss1 = initial_256; + auto sss2 = initial_256; + auto sss3 = initial_256; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); + + // RGBA: Load 8 pixels per line + // source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + // RGB: Load 10 pixels per line (however we can process only 8 pixels): + // source1 = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9 + // ] + auto source1 = + _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i)); + auto source2 = + _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1))); + + // Interleave source1 and source2 from the low half of each 128-bit lane + // and cast the result to epi16 + // RGBA: pix1 = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + // RGB: pix1 = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0 + // ] + auto source_lo = _mm256_unpacklo_epi8(source1, source2); + auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c0 + w1 * C0 + // C += w0 * c1 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // RGBA: pix2 = [ + // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0 + // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0 + // ] + // RGB: pix2 = [ + // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0 + // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0 + // ] + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c2 + w1 * C2 + // C += w0 * c3 + w1 * C3 for each channel in 32-bit precision + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + + // Same as above for the high half of each 128-bit lane + auto source_hi = _mm256_unpackhi_epi8(source1, source2); + auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256); + sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk)); + } + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm256_set1_epi32(k[i]); + + auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size)); + + auto source_lo = _mm256_unpacklo_epi8(source1, zero_256); + auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + sss1 = _mm256_add_epi32(sss1, 
_mm256_madd_epi16(pix2, mmk)); + + auto source_hi = _mm256_unpackhi_epi8(source1, zero_256); + auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256()); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256()); + sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk)); + } + // Convert fixed point values back to integers (truncating) + sss0 = _mm256_srai_epi32(sss0, coefs_precision); + sss1 = _mm256_srai_epi32(sss1, coefs_precision); + sss2 = _mm256_srai_epi32(sss2, coefs_precision); + sss3 = _mm256_srai_epi32(sss3, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss0 = _mm256_packs_epi32(sss0, sss1); + sss2 = _mm256_packs_epi32(sss2, sss3); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss0 = _mm256_packus_epi16(sss0, sss2); + + // Stores 32 bytes + _mm256_storeu_si256((__m256i*)(lineOut + j), sss0); + } + + // TODO: Do we also need block 4 ??? + // block 2 + const auto b2_usable_vec_stride = (8 / data_stride) * data_stride; + for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) { + auto sss0 = initial; + auto sss1 = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load 2 pixels per line + // RGBA: source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size)); + auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size)); + // Interleave source1 and source2 and cast the result to epi16 + // RGBA: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + // RGB: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision + sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk)); + // RGBA: pix = [ + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + // RGB: pix = [ + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0 + // ] + pix = _mm_unpackhi_epi8(source, zero); + // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision + sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk)); + } + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm_set1_epi32(k[i]); + + auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size)); + + auto source = _mm_unpacklo_epi8(source1, zero); + auto pix1 = _mm_unpacklo_epi8(source, zero); + sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk)); + auto pix2 = _mm_unpackhi_epi8(source, zero); + sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk)); + } + // Convert fixed point values back to integers (truncating) + sss0 = _mm_srai_epi32(sss0, coefs_precision); + sss1 = _mm_srai_epi32(sss1, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss0 = _mm_packs_epi32(sss0, sss1); + // Convert 
packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss0 = _mm_packus_epi16(sss0, sss0); + // Store 2 pixels to the output + _mm_storel_epi64((__m128i*)(lineOut + j), sss0); + } + + // block 1 + const auto b1_usable_vec_stride = (4 / data_stride) * data_stride; + const auto i32_aligned = num_channels == 4; + for (; j < data_size - 4; j += b1_usable_vec_stride) { + auto sss = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load one pixel per line + // RGBA: source1 = [ + // r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned); + auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned); + + // Interleave source1 and source2 and cast the result to epi16 + // RGBA: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + // RGB: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + for (; i < ids_size; i++) { + auto mmk = _mm_set1_epi32(k[i]); + auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned); + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + sss = _mm_srai_epi32(sss, coefs_precision); + sss = _mm_packs_epi32(sss, zero); + sss = _mm_packus_epi16(sss, zero); + + auto o = _mm_cvtsi128_si32(sss); + + // Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3 + // It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data. 
+    // We also won't go out of bounds of the lineOut memory allocation
+    std::memcpy(lineOut + j, (uint8_t *) &o, 4);
+  }
+
+  for (; j < data_size; j += data_stride) {
+    auto sss = initial;
+    int64_t i = 0;
+    const auto * lineIn_min = lineIn + j + ids_min;
+    // For RGBA we can use (ids_size - 1) as a tighter limit, but for RGB we could then read outside
+    // the memory boundary for the last remaining line
+    for (; i < ids_size - 2; i += 2) {
+      // Load two coefficients at once
+      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
+
+      // Load 2 lines
+      auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
+      auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
+
+      auto source = _mm_unpacklo_epi8(source1, source2);
+      auto pix = _mm_unpacklo_epi8(source, zero);
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    // Same processing as above but with a single weight value
+    for (; i < ids_size; i++) {
+      auto mmk = _mm_set1_epi32(k[i]);
+
+      const uint8_t * p = lineIn_min + i * data_size;
+      __m128i pix;
+      // There is not much perf gain in using a more detailed condition like
+      // num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
+      // const int64_t in_max_size = data_size * in_ysize;
+      if (num_channels == 3) {
+        uint8_t input[4];
+        std::memcpy(input, p, 3);
+        pix = mm_cvtepu8_epi32(input, true);
+      } else {
+        pix = mm_cvtepu8_epi32(p, true);
+      }
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    // Convert fixed point values back to integers (truncating)
+    sss = _mm_srai_epi32(sss, coefs_precision);
+    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
+    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
+    sss = _mm_packs_epi32(sss, zero);
+    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
+    // (a a b b c c d d) -> (a b c d)
+    sss = _mm_packus_epi16(sss, zero);
+    // Store one pixel to the output
+    auto o = _mm_cvtsi128_si32(sss);
+    if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
+      std::memcpy(lineOut + j, (uint8_t *) &o, 3);
+    } else {
+      std::memcpy(lineOut + j, (uint8_t *) &o, 4);
+    }
+  }
+}
+
+} // anonymous namespace
+#endif // CPU_CAPABILITY_AVX2
diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..2fc0e15e5201df6d5f88689c70cb7f84b59fd99e
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h
@@ -0,0 +1,20 @@
+#pragma once
+#include <ATen/native/DispatchStub.h>
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+using weight_norm_fn = void(*)(
+    TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t);
+using weight_norm_backward_fn = void(*)(
+    TensorBase&, TensorBase&, const TensorBase&, const TensorBase&,
+    const TensorBase&, const TensorBase&, int64_t);
+
+DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub)
+DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub)
+
+} // namespace at::native
diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h
new file mode 100644
index 0000000000000000000000000000000000000000..ce37f0aecb8cb88758703b3783c720d840a3d926
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h
@@ -0,0 +1,522 @@
+#pragma once
+/*
+  AVX implementation of
sin, cos, sincos, exp and log + + Based on "sse_mathfun.h", by Julien Pommier + http://gruntthepeon.free.fr/ssemath/ + + Copyright (C) 2012 Giovanni Garberoglio + Interdisciplinary Laboratory for Computational Science (LISC) + Fondazione Bruno Kessler and University of Trento + via Sommarive, 18 + I-38123 Trento (Italy) + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + (this is the zlib license) +*/ + +#include + +/* The original source of this file has been modified. */ +#if defined(CPU_CAPABILITY_AVX2) + +#if defined(__GNUC__) +# define ALIGN32_BEG __attribute__((aligned(32))) +#elif defined(_WIN32) +# define ALIGN32_BEG __declspec(align(32)) +#endif + +typedef __m256 v8sf; // vector of 8 float (avx2) +typedef __m256i v8si; // vector of 8 int (avx2) + +/* declare some AVX constants -- why can't I figure a better way to do that? */ +#define _PS256_CONST(Name, Val) \ + static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PI32_CONST256(Name, Val) \ + static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PS256_CONST_TYPE(Name, Type, Val) \ + static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } + +_PS256_CONST(1 , 1.0f); +_PS256_CONST(0p5, 0.5f); +/* the smallest non denormalized float number */ +_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000); +_PS256_CONST_TYPE(mant_mask, int, 0x7f800000); +_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000); + +_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000); +_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000); + +_PI32_CONST256(0, 0); +_PI32_CONST256(1, 1); +_PI32_CONST256(inv1, ~1); +_PI32_CONST256(2, 2); +_PI32_CONST256(4, 4); +_PI32_CONST256(0x7f, 0x7f); + +_PS256_CONST(cephes_SQRTHF, 0.707106781186547524); +_PS256_CONST(cephes_log_p0, 7.0376836292E-2); +_PS256_CONST(cephes_log_p1, - 1.1514610310E-1); +_PS256_CONST(cephes_log_p2, 1.1676998740E-1); +_PS256_CONST(cephes_log_p3, - 1.2420140846E-1); +_PS256_CONST(cephes_log_p4, + 1.4249322787E-1); +_PS256_CONST(cephes_log_p5, - 1.6668057665E-1); +_PS256_CONST(cephes_log_p6, + 2.0000714765E-1); +_PS256_CONST(cephes_log_p7, - 2.4999993993E-1); +_PS256_CONST(cephes_log_p8, + 3.3333331174E-1); +_PS256_CONST(cephes_log_q1, -2.12194440e-4); +_PS256_CONST(cephes_log_q2, 0.693359375); + + +/* natural logarithm computed for 8 simultaneous float + return NaN for x <= 0 +*/ +inline v8sf log256_ps(v8sf x) { + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps()); + v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS); + + x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */ + + // can be done with AVX2 + 
imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23); + + /* keep only the fractional part */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask); + x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5); + + // this is again another AVX2 instruction + imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f); + v8sf e = _mm256_cvtepi32_ps(imm0); + + e = _mm256_add_ps(e, one); + + /* part2: + if( x < SQRTHF ) { + e -= 1; + x = x + x - 1.0; + } else { x = x - 1.0; } + */ + //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF); + v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS); + v8sf tmp = _mm256_and_ps(x, mask); + x = _mm256_sub_ps(x, one); + e = _mm256_sub_ps(e, _mm256_and_ps(one, mask)); + x = _mm256_add_ps(x, tmp); + + v8sf z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_log_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8); + y = _mm256_mul_ps(y, x); + + y = _mm256_mul_ps(y, z); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1); + y = _mm256_add_ps(y, tmp); + + + tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5); + y = _mm256_sub_ps(y, tmp); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2); + x = _mm256_add_ps(x, y); + x = _mm256_add_ps(x, tmp); + x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN + return x; +} + +_PS256_CONST(exp_hi, 88.3762626647949f); +_PS256_CONST(exp_lo, -88.3762626647949f); + +_PS256_CONST(cephes_LOG2EF, 1.44269504088896341); +_PS256_CONST(cephes_exp_C1, 0.693359375); +_PS256_CONST(cephes_exp_C2, -2.12194440e-4); + +_PS256_CONST(cephes_exp_p0, 1.9875691500E-4); +_PS256_CONST(cephes_exp_p1, 1.3981999507E-3); +_PS256_CONST(cephes_exp_p2, 8.3334519073E-3); +_PS256_CONST(cephes_exp_p3, 4.1665795894E-2); +_PS256_CONST(cephes_exp_p4, 1.6666665459E-1); +_PS256_CONST(cephes_exp_p5, 5.0000001201E-1); + +inline v8sf exp256_ps(v8sf x) { + v8sf tmp = _mm256_setzero_ps(), fx; + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi); + x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF); + fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5); + + /* how to perform a floorf with SSE: just below */ + //imm0 = _mm256_cvttps_epi32(fx); + //tmp = _mm256_cvtepi32_ps(imm0); + + tmp = _mm256_floor_ps(fx); + + /* if greater, subtract 1 */ + //v8sf mask = _mm256_cmpgt_ps(tmp, fx); + v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); + mask = _mm256_and_ps(mask, one); + fx = _mm256_sub_ps(tmp, mask); + + tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1); + v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2); + x = _mm256_sub_ps(x, tmp); + x = _mm256_sub_ps(x, z); + + z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_exp_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2); + y = _mm256_mul_ps(y, x); + y = 
_mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5); + y = _mm256_mul_ps(y, z); + y = _mm256_add_ps(y, x); + y = _mm256_add_ps(y, one); + + /* build 2^n */ + imm0 = _mm256_cvttps_epi32(fx); + // another two AVX2 instructions + imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f); + imm0 = _mm256_slli_epi32(imm0, 23); + v8sf pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} + +_PS256_CONST(minus_cephes_DP1, -0.78515625); +_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4); +_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8); +_PS256_CONST(sincof_p0, -1.9515295891E-4); +_PS256_CONST(sincof_p1, 8.3321608736E-3); +_PS256_CONST(sincof_p2, -1.6666654611E-1); +_PS256_CONST(coscof_p0, 2.443315711809948E-005); +_PS256_CONST(coscof_p1, -1.388731625493765E-003); +_PS256_CONST(coscof_p2, 4.166664568298827E-002); +_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI + + +/* evaluation of 8 sines at onces using AVX intrinsics + + The code is the exact rewriting of the cephes sinf function. + Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. + +*/ +inline v8sf sin256_ps(v8sf x) { // any x + v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y; + v8si imm0, imm2; + + sign_bit = x; + /* take the absolute value */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); + /* extract the sign bit (upper one) */ + sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask); + + /* scale by 4/Pi */ + y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); + + /* + Here we start a series of integer operations, which are in the + realm of AVX2. 
+ If we don't have AVX, let's perform them using SSE2 directives + */ + + /* store the integer part of y in mm0 */ + imm2 = _mm256_cvttps_epi32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + // another two AVX2 instruction + imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1); + imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1); + y = _mm256_cvtepi32_ps(imm2); + + /* get the swap sign flag */ + imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4); + imm0 = _mm256_slli_epi32(imm0, 29); + /* get the polynom selection mask + there is one polynom for 0 <= x <= Pi/4 + and another one for Pi/4 +#include + +namespace at::native { + +using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&); +using int4pack_mm_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&); +using int8pack_mm_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); +using dyn_quant_pack_4bit_weight_fn = void (*)( + Tensor&, + const Tensor&, + const Tensor&, + const std::optional& bias, + const int64_t, + const int64_t, + const int64_t); +using dyn_quant_matmul_4bit_fn = void (*)( + const Tensor&, + const Tensor&, + const Tensor&, + const int64_t, + const int64_t, + const int64_t, + const int64_t); + +DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub) +DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub) +DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub) +DECLARE_DISPATCH( + dyn_quant_pack_4bit_weight_fn, + dyn_quant_pack_4bit_weight_stub) +DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h new file mode 100644 index 0000000000000000000000000000000000000000..6bb14031e487ea1ea4838ba3dc6a8a52e8627f18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace at::native { + +inline ScalarType first_type() { + return ScalarType::Undefined; +} + +template +inline ScalarType first_type(const Tensor& arg, const Args&... parameters) { + return arg.defined() ? arg.scalar_type() : first_type(parameters...); +} + +template +inline bool is_mixed_type(const Tensor& input, const Args&... parameters) { + const auto parameter_type = first_type(parameters...); + return ((parameter_type != ScalarType::Undefined) && + (parameter_type != input.scalar_type())); +} + +// currently on CPU, mixed data type is only supported +// when input is 'BFloat16' or 'Half' and parameters are 'Float' +inline void check_mixed_data_type(const Tensor& input) { + TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()), + "mixed dtype (CPU): all inputs must share same datatype."); +} + +template +inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) { + TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float, + "mixed dtype (CPU): expect parameter to have scalar type of Float"); + check_mixed_data_type(input, parameters...); +} + +inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) { + return is_mixed_type ? 
ScalarType::Float : t.scalar_type(); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/moments_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/moments_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..93d845a804440b5506e9d45a0f06956725db6d70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/moments_utils.h @@ -0,0 +1,202 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +template using opmath_t = at::opmath_type; + +constexpr int64_t kChunkSize = 16; + +template +void AddMoments( + int64_t m0_add, + const T& m1_add, + const T& m2_add, + int64_t& m0, + T& m1, + T& m2) { + const int64_t n = m0 + m0_add; + const T c = n == 0 ? static_cast(0) : static_cast(m0_add) / static_cast(n); + const T delta = m1_add - m1; + m1 += c * delta; + m2 += m2_add + delta * delta * c * static_cast(m0); + m0 = n; +} + +template +C10_ALWAYS_INLINE void AddMomentsVec( + int64_t m0_add, + const vec::Vectorized& m1_add, + const vec::Vectorized& m2_add, + int64_t& m0, + vec::Vectorized& m1, + vec::Vectorized& m2) { + using Vec = vec::Vectorized; + const int64_t n = m0 + m0_add; + const T c = n == 0 ? static_cast(0) : static_cast(m0_add) / static_cast(n); + const Vec c_vec(c); + const Vec delta = m1_add - m1; + m1 += c_vec * delta; + m2 += m2_add + delta * delta * c_vec * Vec(static_cast(m0)); + m0 = n; +} + +template +inline std::enable_if_t>, void> +UpdateMomentsVec( + int64_t m0, + const T* X_ptr, + const std::array>, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized>& m1_stk0, + vec::Vectorized>& m2_stk0) { + using Vec = vec::Vectorized>; + Vec m1_vec(0); + Vec m2_vec(0); + for (const auto j : c10::irange(m0)) { + const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size()); + const Vec delta_vec = x_vec - m1_vec; + m1_vec += delta_vec * c_vecs[j]; + m2_vec += delta_vec * (x_vec - m1_vec); + } + AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0); +} + +// each bfloat16/half vector will be converted to two float vectors, +// and accumulated successively on m1_stk0/m2_stk0. +template +inline std::enable_if_t>, void> +UpdateMomentsVec( + int64_t m0, + const T* X_ptr, + const std::array>, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized>& m1_stk0, + vec::Vectorized>& m2_stk0) { + using Vec = vec::Vectorized; + using fVec = vec::Vectorized>; + fVec m1_fvec0(0), m1_fvec1(0); + fVec m2_fvec0(0), m2_fvec1(0); + for (const auto j : c10::irange(m0)) { + const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size()); + auto [x_fvec0, x_fvec1] = convert_to_float(x_bvec); + const fVec delta_fvec0 = x_fvec0 - m1_fvec0; + const fVec delta_fvec1 = x_fvec1 - m1_fvec1; + m1_fvec0 += delta_fvec0 * c_vecs[j]; + m1_fvec1 += delta_fvec1 * c_vecs[j]; + m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0); + m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1); + } + AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0); + AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0); +} + +// Compute rowwise moments by Welford algorithm and cascade sum to improve +// numerical stability. 
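// A minimal worked example of the AddMoments merge used above (illustrative numbers only):
// merging the partial aggregate of {1, 2}  (m0 = 2, m1 = 1.5, m2 = 0.5)
// with the partial aggregate of {4}        (m0_add = 1, m1_add = 4, m2_add = 0)
// gives n = 3, c = 1/3, delta = 2.5, so
//   m1 = 1.5 + (1/3) * 2.5           = 2.3333...  (the mean of {1, 2, 4})
//   m2 = 0.5 + 0 + 2.5^2 * (1/3) * 2 = 4.6667...  (the sum of squared deviations of {1, 2, 4})
// which matches a direct single-pass computation over {1, 2, 4}.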
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +// https://en.wikipedia.org/wiki/Pairwise_summation +template +std::pair, opmath_t> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) { + using math_t = opmath_t; + + constexpr int64_t kVecSize = vec::Vectorized::size(); + constexpr int64_t kAccVecSize = vec::Vectorized::size(); + const int64_t n = N / kVecSize; + const int64_t m = divup(n, kChunkSize); + const int64_t depth = utils::CeilLog2(m); + + using Vec = vec::Vectorized; + const Vec kZeroVec(math_t(0)); + c10::SmallVector m0_stk(depth, 0); + c10::SmallVector m1_stk(depth, kZeroVec); + c10::SmallVector m2_stk(depth, kZeroVec); + + for (const auto i : c10::irange(m)) { + const T* X_ptr = X + i * kChunkSize * kVecSize; + const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize); + static std::array c_vecs = ([]() { + std::array result; + for (const auto i : c10::irange(kChunkSize)) { + result[i] = Vec(math_t(1) / static_cast(i + 1)); + } + return result; + })(); + UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]); + + int64_t mask = i + 1; + for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) { + AddMomentsVec( + m0_stk[j - 1], + m1_stk[j - 1], + m2_stk[j - 1], + m0_stk[j], + m1_stk[j], + m2_stk[j]); + m0_stk[j - 1] = 0; + m1_stk[j - 1] = kZeroVec; + m2_stk[j - 1] = kZeroVec; + mask >>= 1; + } + } + for (const auto i : c10::irange(1, depth)) { + AddMomentsVec( + m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]); + } + + std::array m1_arr{}; + std::array m2_arr{}; + m1_stk[0].store(m1_arr.data()); + m2_stk[0].store(m2_arr.data()); + + int64_t m0 = 0; + math_t m1 = 0; + math_t m2 = 0; + for (int64_t i = n * kVecSize; i < N; ++i) { + math_t x = static_cast(X[i]); + const math_t delta = x - m1; + ++m0; + m1 += delta / static_cast(m0); + m2 += delta * (x - m1); + } + // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result + int64_t m0_add = n * kVecSize / kAccVecSize; + for (const auto i : c10::irange(kAccVecSize)) { + AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2); + } + + return std::make_pair(m1, m2 / static_cast(N - ddof)); +} + +template +std::pair, opmath_t> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) { + using Vec = vec::Vectorized; + constexpr int64_t kVecSize = Vec::size(); + const int64_t n = N / kVecSize; + const int64_t m = divup(n, kChunkSize); + const int64_t depth = utils::CeilLog2(m); + if (depth <= 4) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 8) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 16) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 32) { + return RowwiseMomentsImpl(X, N, ddof); + } else { + return RowwiseMomentsImpl(X, N, ddof); + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3edbd80312070f3539afab503487d9da247e8470 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/utils.h @@ -0,0 +1,212 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef USE_FBGEMM +#include +#endif + +namespace at::native { + +template +inline void _store(T* dst, at::vec::Vectorized src) { + src.store(dst); +} + +inline void _store(at::BFloat16* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_bfloat16(src, src); + res.store(dst, 
at::vec::Vectorized::size()); +} + +inline void _store(at::Half* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_half(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +inline namespace CPU_CAPABILITY { + +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +// Helper struct for bfloat16/float16 vectorization +// Useful when you need float as immediate dtype or accumulate dtype +using namespace vec; +struct Vec2 { + Vectorized val0, val1; + Vec2(Vectorized v0, Vectorized v1) : val0(v0), val1(v1) {} + Vec2(float v) : val0(v), val1(v) {} + static Vec2 loadu(const BFloat16* ptr) { + auto [v0, v1] = convert_bfloat16_float(Vectorized::loadu(ptr)); + return {v0, v1}; + } + static Vec2 loadu(const Half* ptr) { + auto [v0, v1] = convert_half_float(Vectorized::loadu(ptr)); + return {v0, v1}; + } + static Vec2 loadu(const float* ptr) { + return {Vectorized::loadu(ptr), Vectorized::loadu(ptr + Vectorized::size())}; + } + void store(BFloat16* ptr) const { + Vectorized val = convert_float_bfloat16(val0, val1); + val.store(ptr); + } + void store(Half* ptr) const { + Vectorized val = convert_float_half(val0, val1); + val.store(ptr); + } + void store(float* ptr) const { + val0.store(ptr); + val1.store(ptr + Vectorized::size()); + } +}; +inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; } +inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; } +inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; } +inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; } +inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; } +inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; } + +template struct VectorizedType { using type = Vectorized; }; +template <> struct VectorizedType { using type = Vec2; }; +template <> struct VectorizedType { using type = Vec2; }; +template using VecType = typename VectorizedType::type; + +// Helper for mixed data type parameter Vec::load +inline std::tuple, Vectorized> load2f(const BFloat16* ptr) { + return convert_bfloat16_float(Vectorized::loadu(ptr)); +} + +inline std::tuple, Vectorized> load2f(const Half* ptr) { + return convert_half_float(Vectorized::loadu(ptr)); +} + +inline std::tuple, Vectorized> load2f(const float* ptr) { + using Vec = Vectorized; + return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size())); +} + +inline std::tuple, Vectorized> load2f(const BFloat16* ptr, int64_t count) { + return convert_bfloat16_float(Vectorized::loadu(ptr, count)); +} + +inline std::tuple, Vectorized> load2f(const Half* ptr, int64_t count) { + return convert_half_float(Vectorized::loadu(ptr, count)); +} + +inline std::tuple, Vectorized> load2f(const float* ptr, int64_t count) { + using Vec = Vectorized; + if (count > Vec::size()) { + return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + 
Vec::size(), count - Vec::size())); + } else { + return std::make_tuple(Vec::loadu(ptr, count), Vec(0)); + } +} + +} // namespace + +namespace utils { + +template +T CeilLog2(const T& x) { + if (x <= 2) { + return 1; + } + // Last set bit is floor(log2(x)), floor + 1 is ceil + // except when x is an exact powers of 2, so subtract 1 first + return static_cast(llvm::findLastSet(static_cast(x) - 1)) + 1; +} + +// matrix transpose: +// src has shape of M by N, with leading dimension of ld_src +// dst has shape of N by M, with leading dimension of ld_dst +template +inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { + for (int64_t j = 0; j < N; j++) { + for (int64_t i = 0; i < M; i++) { + dst[j * ld_dst + i] = c10::load(&(src[i * ld_src + j])); + } + } +} + +#ifdef USE_FBGEMM +template <> +inline void transpose(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} + +template <> +inline void transpose(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} +#endif + +template +inline void parallel_sparse_csr( + const TensorAccessor& crow_acc, + const int64_t M, + const int64_t nnz, + const F& f) { + TORCH_CHECK(crow_acc.size(0) == M + 1); + + // directly parallel on `M` may lead to load imbalance, + // statically determine thread partition here to average payload + // for each thread. + int num_threads = at::get_num_threads(); + std::vector thread_splits(num_threads + 1, M); + + int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads)); + + thread_splits[0] = 0; + int64_t sum = 0; + int64_t t = 1; + for (const auto m : c10::irange(M)) { + int64_t row_start = crow_acc[m]; + int64_t row_end = crow_acc[m + 1]; + sum += row_end - row_start; + if (sum > t * thread_averge_payload) { + thread_splits[t] = m; + t++; + } + } + // need to restore the last index, + // due to rounding error when calculating `thread_averge_payload`. + thread_splits[num_threads] = M; + + at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) { + int tid = at::get_thread_num(); + int64_t begin = thread_splits[tid]; + int64_t end = thread_splits[tid + 1]; + f(begin, end); + }); +} + +} // namespace utils + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/zmath.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/zmath.h new file mode 100644 index 0000000000000000000000000000000000000000..3ebef428705350d4dffe37b612b0f368ea87cb78 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cpu/zmath.h @@ -0,0 +1,250 @@ +#pragma once + +// Complex number math operations that act as no-ops for other dtypes. 
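// For example (an illustrative reading, assuming the usual <SCALAR_TYPE, VALUE_TYPE> instantiations):
//   zabs<float, float>(-3.0f)                                 -> -3.0f        (identity for real dtypes)
//   zabs<c10::complex<float>, float>({3.0f, 4.0f})            -> 5.0f         (magnitude as the real value type)
//   zabs<c10::complex<float>, c10::complex<float>>({3, 4})    -> (5.0f, 0.0f) (magnitude re-wrapped as complex)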
+#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +template +inline VALUE_TYPE zabs (SCALAR_TYPE z) { + return z; +} + +template<> +inline c10::complex zabs > (c10::complex z) { + return c10::complex(std::abs(z)); +} + +template<> +inline float zabs , float> (c10::complex z) { + return std::abs(z); +} + +template<> +inline c10::complex zabs > (c10::complex z) { + return c10::complex(std::abs(z)); +} + +template<> +inline double zabs , double> (c10::complex z) { + return std::abs(z); +} + +// This overload corresponds to non-complex dtypes. +// The function is consistent with its NumPy equivalent +// for non-complex dtypes where `pi` is returned for +// negative real numbers and `0` is returned for 0 or positive +// real numbers. +// Note: `nan` is propagated. +template +inline VALUE_TYPE angle_impl (SCALAR_TYPE z) { + if (at::_isnan(z)) { + return z; + } + return z < 0 ? c10::pi : 0; +} + +template<> +inline c10::complex angle_impl > (c10::complex z) { + return c10::complex(std::arg(z), 0.0); +} + +template<> +inline float angle_impl , float> (c10::complex z) { + return std::arg(z); +} + +template<> +inline c10::complex angle_impl > (c10::complex z) { + return c10::complex(std::arg(z), 0.0); +} + +template<> +inline double angle_impl , double> (c10::complex z) { + return std::arg(z); +} + +template +constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) { + return z; //No-Op +} + +template<> +constexpr c10::complex real_impl > (c10::complex z) { + return c10::complex(z.real(), 0.0); +} + +template<> +constexpr float real_impl , float> (c10::complex z) { + return z.real(); +} + +template<> +constexpr c10::complex real_impl > (c10::complex z) { + return c10::complex(z.real(), 0.0); +} + +template<> +constexpr double real_impl , double> (c10::complex z) { + return z.real(); +} + +template +constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) { + return 0; +} + +template<> +constexpr c10::complex imag_impl > (c10::complex z) { + return c10::complex(z.imag(), 0.0); +} + +template<> +constexpr float imag_impl , float> (c10::complex z) { + return z.imag(); +} + +template<> +constexpr c10::complex imag_impl > (c10::complex z) { + return c10::complex(z.imag(), 0.0); +} + +template<> +constexpr double imag_impl , double> (c10::complex z) { + return z.imag(); +} + +template +inline TYPE conj_impl (TYPE z) { + return z; //No-Op +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex{z.real(), -z.imag()}; +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex(z.real(), -z.imag()); +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex(z.real(), -z.imag()); +} + +template +inline TYPE ceil_impl (TYPE z) { + return std::ceil(z); +} + +template <> +inline c10::complex ceil_impl (c10::complex z) { + return c10::complex(std::ceil(z.real()), std::ceil(z.imag())); +} + +template <> +inline c10::complex ceil_impl (c10::complex z) { + return c10::complex(std::ceil(z.real()), std::ceil(z.imag())); +} + +template +inline c10::complex sgn_impl (c10::complex z) { + if (z == c10::complex(0, 0)) { + return c10::complex(0, 0); + } else { + return z / zabs(z); + } +} + +template +inline TYPE floor_impl (TYPE z) { + return std::floor(z); +} + +template <> +inline c10::complex floor_impl (c10::complex z) { + return c10::complex(std::floor(z.real()), std::floor(z.imag())); +} + +template <> +inline c10::complex floor_impl (c10::complex z) { + return c10::complex(std::floor(z.real()), 
std::floor(z.imag())); +} + +template +inline TYPE round_impl (TYPE z) { + return std::nearbyint(z); +} + +template <> +inline c10::complex round_impl (c10::complex z) { + return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); +} + +template <> +inline c10::complex round_impl (c10::complex z) { + return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); +} + +template +inline TYPE trunc_impl (TYPE z) { + return std::trunc(z); +} + +template <> +inline c10::complex trunc_impl (c10::complex z) { + return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); +} + +template <> +inline c10::complex trunc_impl (c10::complex z) { + return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); +} + +template ::value, int> = 0> +inline TYPE max_impl (TYPE a, TYPE b) { + if (_isnan(a) || _isnan(b)) { + return std::numeric_limits::quiet_NaN(); + } else { + return std::max(a, b); + } +} + +template ::value, int> = 0> +inline TYPE max_impl (TYPE a, TYPE b) { + if (_isnan(a)) { + return a; + } else if (_isnan(b)) { + return b; + } else { + return std::abs(a) > std::abs(b) ? a : b; + } +} + +template ::value, int> = 0> +inline TYPE min_impl (TYPE a, TYPE b) { + if (_isnan(a) || _isnan(b)) { + return std::numeric_limits::quiet_NaN(); + } else { + return std::min(a, b); + } +} + +template ::value, int> = 0> +inline TYPE min_impl (TYPE a, TYPE b) { + if (_isnan(a)) { + return a; + } else if (_isnan(b)) { + return b; + } else { + return std::abs(a) < std::abs(b) ? a : b; + } +} + +} // end namespace +} //end at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Activation.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Activation.h new file mode 100644 index 0000000000000000000000000000000000000000..9823bae18f8be70120a442e26fd85165b80b5acc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Activation.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at::native { + +void launch_glu_backward_kernel(const TensorIteratorBase& iter, + int64_t gI_stride, int64_t I_stride); + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter); + +void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); +void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h new file mode 100644 index 0000000000000000000000000000000000000000..e288f7d9ad74529d111776bd1c0a16572bee57eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h @@ -0,0 +1,44 @@ +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. 
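// A hypothetical usage from a Binary*.cu translation unit (names and dispatch macro illustrative):
//
//   #include <ATen/native/cuda/BinaryInternal.h>
//
//   void mul_kernel_cuda(TensorIteratorBase& iter) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "mul_cuda", [&] {
//       gpu_kernel_with_scalars(iter, binary_internal::MulFunctor<scalar_t>());
//     });
//   }
//
// Keeping the include confined to .cu files keeps these __device__ functors out of host-only headers.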
+#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native::binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_cuda(TensorIteratorBase& iter); +void div_trunc_kernel_cuda(TensorIteratorBase& iter); +} // namespace at::native::binary_internal diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b068b4560acd6fd17d31de83003bd7385ed88b2b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh @@ -0,0 +1,327 @@ +#pragma once +#include + +// Jiterator functions are guarded behind this macro +#if AT_USE_JITERATOR() + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace at::native { + +template +// warning : unused parameter when tuple is empty. +constexpr auto tuple_to_array_helper(const Tuple& t [[maybe_unused]], std::index_sequence seq) { + constexpr auto size = seq.size(); + return std::array{static_cast(&std::get(t))...}; +} + +// Helper function convert tuple to std::array +// for passing the arguments to CUDA Kernel +// NOTE: We capture tuple by reference, +// so the pointers in returned array are only valid +// till tuple is alive. 
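// For instance (illustrative): given  std::tuple<int, float> extra{2, 0.5f};
//   auto arr = tuple_to_array(extra);  // an array of type-erased pointers into `extra`
// the pointers in `arr` dangle as soon as `extra` is destroyed, so the array must be
// consumed (e.g. handed to the kernel launch) while `extra` is still alive.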
+template +constexpr auto tuple_to_array(const std::tuple& extra_args) { + constexpr auto tuple_size = sizeof...(Args); + return tuple_to_array_helper(extra_args, std::make_index_sequence{}); +} + +struct JittedVecKernelCache { + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + at::cuda::jit::NvrtcFunction vec1; + at::cuda::jit::NvrtcFunction vec2; + at::cuda::jit::NvrtcFunction vec4; + at::cuda::jit::NvrtcFunction vec8; +#ifdef USE_ROCM + at::cuda::jit::NvrtcFunction vec16; +#endif + +}; + +struct JittedKernelVariantCache { + JittedVecKernelCache vec; + at::cuda::jit::NvrtcFunction noncontiguous; + at::cuda::jit::NvrtcFunction dynamic_contiguous; + at::cuda::jit::NvrtcFunction dynamic_noncontiguous; +}; + +inline c10::SmallBuffer pack_kernel_args( + std::initializer_list args, + c10::ArrayRef extra_args) { + c10::SmallBuffer ret(args.size() + extra_args.size()); + std::copy(args.begin(), args.end(), ret.data()); + std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size()); + return ret; +} + +template +void launch_jitted_unrolled_kernel( + std::mutex &jiterator_mutex, + at::cuda::jit::NvrtcFunction &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, + int64_t N, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s, + bool contiguous, + at::cuda::jit::BinaryFuncVariant scalar_pos, + const void* scalar_val, + c10::ArrayRef extra_args) { + + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + + int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type); + int bws = tws * num_threads(); + //casting result to int is always safe, intermediate is int64 and won't overflow + const uint32_t grid = (N + bws - 1) / bws; + + if (!fn_cache.function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_cache.function) { + constexpr bool dynamic_casting = !std::is_same() || + !std::is_same(); + auto code = at::cuda::jit::generate_code( + desc, contiguous, dynamic_casting, scalar_pos, tws); + fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name); + } + } + + auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u}, + {num_threads(), 1u, 1u}); +} + +template +void launch_jitted_vectorized_kernel( + std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data, + at::cuda::jit::BinaryFuncVariant scalar_pos, + const void *scalar_val, c10::ArrayRef extra_args) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + + int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type); + int bws = tws * num_threads(); + // N is still int64_t for the computation, but it's always safe to cast result to int + const uint32_t grid = (N + bws - 1) / bws; + + int vec_size = at::cuda::jit::can_vectorize_up_to( + desc, c10::ArrayRef(data.data(), data.size())); + +#ifndef USE_ROCM + const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + const int optimal_vec_size = 16 / static_cast(input_size); + vec_size = std::min(optimal_vec_size, vec_size); + // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC + // that causes some numerical mismatches with uint8 on sm80 and sm90. + // TODO: Revisit this after CUDA 12.8 update. 
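// Worked example (illustrative): for float inputs, input_size == 4, so
// optimal_vec_size == 16 / 4 == 4 and at most 4 elements are packed per access;
// for uint8 inputs, input_size == 1 and optimal_vec_size would be 16, but the
// clamp below keeps vec_size at 4 to sidestep the NVCC issue mentioned above.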
+ if (input_size < 2) { + vec_size = std::min(vec_size, 4); + } +#endif + + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + // fn_ptr is set to the appropriate function based on the vec size and GPU used + at::cuda::jit::NvrtcFunction* fn_ptr = nullptr; + +#ifdef USE_ROCM + if (vec_size == 16) { + fn_ptr = &fn_cache.vec16; + } else +#endif + if (vec_size == 8) { + fn_ptr = &fn_cache.vec8; + } else if (vec_size == 4) { + fn_ptr = &fn_cache.vec4; + } else if (vec_size == 2) { + fn_ptr = &fn_cache.vec2; + } else if (vec_size ==1) { + fn_ptr = &fn_cache.vec1; + } else { + TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel"); + } + + bool vectorized = vec_size > 1; + + if (!fn_ptr->function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_ptr->function) { // cache miss! + + // Generates program + auto code = at::cuda::jit::generate_code( + desc, /*contiguous=*/true, /*dynamic_casting=*/false, + scalar_pos, tws, vectorized, vec_size); + std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name; + + // Acquires the program + *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name); + } + } + + if (vectorized) { + auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); + } else { +// NVCC complains about unused variables l and s. +// It should be false positive in most cases, so we suppress the warnings. +#pragma nv_diagnostic push +#pragma nv_diag_suppress 177 + auto ic = TrivialOffsetCalculator(); + auto oc = TrivialOffsetCalculator<1>(); + auto l = memory::LoadWithoutCast(); + auto s = memory::StoreWithoutCast(); + + auto args = pack_kernel_args( + {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); +#pragma nv_diagnostic pop + } +} + +template +void jitted_gpu_kernel_generic( + std::mutex &jiterator_mutex, + JittedKernelVariantCache &cache, + const at::cuda::jit::KernelDescriptor &desc, + at::cuda::jit::BinaryFuncVariant scalar_pos, + c10::ArrayRef extra_args, + TensorIteratorBase& iter, + const bool dynamic_casting, + const void *scalar_val) { + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + constexpr int ntensors = arity + 1; + std::array data; + for (auto i : c10::irange(ntensors)) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + bool contiguous = iter.is_contiguous(); + + // Decides which of 4 kernel types to launch + // Variations are: + // - Case 1: no dynamic casting and contiguous + // - Case 2: no dynamic casting and noncontiguous + // - Case 3: dynamic casting and contiguous + // - Case 4: dynamic casting and noncontiguous + // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl + + if (!dynamic_casting) { + if (contiguous) { + // Case 1: no dynamic casting and contiguous + launch_jitted_vectorized_kernel( + jiterator_mutex, cache.vec, desc, + numel, data, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 2: no dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + auto loader = memory::LoadWithoutCast(); + auto 
storer = memory::StoreWithoutCast(); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.noncontiguous, desc, numel, data, + input_offset_calculator, output_offset_calculator, loader, + storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Cases 3 and 4 are handled below + // Both require construction of a storer (this asserts 1 output) and one or more loaders + + // Creates store cast to output (the zeroth tensor in TensorIterator) + auto storer = memory::StoreWithCast<1>(iter); + + // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors) + auto loader = memory::LoadWithCast(iter); + + if (contiguous) { + // Case 3: dynamic casting and contiguous + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 4: dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); +} + +// NOTE: static to reduce chances of name collision. +template < + char const* name, + typename result_type, + typename f_inputs_type, + int arity, + at::cuda::jit::BinaryFuncVariant scalar_pos = + at::cuda::jit::BinaryFuncVariant::NoScalar, + typename... ExtraArgs> +static void jitted_gpu_kernel_impl( + TensorIteratorBase& iter, + const std::string &f, + const bool dynamic_casting, + at::opmath_type scalar_val, + const std::tuple& extra_args) { + + // TODO: Memory use can probably be optimized by re-using kernels across GPUs with + // the same compute capability + static std::mutex jiterator_mutex; + static std::vector device_caches(c10::cuda::device_count()); + + constexpr int nInputs = arity; + constexpr int nOutputs = 1; // TODO: Support more than 1 output + static const auto desc = at::cuda::jit::make_kernel_descriptor< + result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs); + + auto &cache = device_caches[iter.device().index()]; + auto extra_args_array = tuple_to_array(extra_args); + return jitted_gpu_kernel_generic( + jiterator_mutex, + cache, + desc, + scalar_pos, + extra_args_array, + iter, + dynamic_casting, + &scalar_val + ); +} + +} // at::native + +#endif // AT_USE_JITERATOR() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a274f327499a6d9f7797fc1b42659c9d79cb1e31 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh @@ -0,0 +1,1010 @@ +#pragma once + +// This file provides two functions to help write GPU elementwise kernels: +// +// gpu_kernel(TensorIterator iter, ) +// gpu_kernel_with_scalars(TensorIterator iter, ) +// +// The gpu_kernel_with_scalars generates specializations that support a +// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar +// is lifted to a kernel parameter instead of copying to device memory. 
+// This should be used in conjunction with TensorIterator::allow_cpu_scalars_, +// which is the default for TensorIterator::binary_op. Otherwise, all inputs +// and the output must be on the GPU. +// +// For example, to write a reciprocal kernel for GPU float Tensors: +// +// gpu_kernel(iter, []GPU_LAMBDA(float a) { +// return 1.0f / a; +// }); +// +// To write a multiplication kernel for GPU float Tensors where one argument +// may be a CPU scalar: +// +// gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) { +// return a * b; +// }); +// +// See BinaryOpsKernel.cu for the complete implementation +// + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __NVCC__ +#define ASSERT_HOST_DEVICE_LAMBDA(type) \ + static_assert( \ + __nv_is_extended_host_device_lambda_closure_type(type), \ + #type " must be a __host__ __device__ lambda") +#else +#define ASSERT_HOST_DEVICE_LAMBDA(type) +#endif + +namespace at::native { + +#ifdef USE_ROCM +// Custom configuration for vectorized elementwise kernel +// with template instantiation. +namespace vectorized_templated_config { +constexpr int num_threads() { + return 512; +} + +constexpr int elems_per_thread() { + return 32; +} + +constexpr int block_work_size() { + return elems_per_thread() * num_threads(); +} +} // namespace vectorized_templated_config +#endif + +template +constexpr auto sum_of_sizes(args_t args, std::index_sequence) { + if constexpr (sizeof...(Is) == 0) { + return 0; + } else { + return (sizeof(std::tuple_element_t) + ...); + } +} + +#ifdef USE_ROCM +template +constexpr auto elems_per_thread(){ + if constexpr (io_sizes == 1) { + return 16; + } else if constexpr (io_sizes < 4) { + return 8; + } else { + return 4; + } +} +#else +template +constexpr auto elems_per_thread(){ + if constexpr (io_sizes == 1) { + return 16; + } else { + return 8; + } +} +#endif + + +//thread work size of 8 regresses the perf of elementwise kernel on cuda +//this doesn't change ROCm behavior as thread_work_size is already 4 on ROCm +constexpr int elementwise_thread_work_size() {return 4;} +constexpr int elementwise_block_work_size() { + return elementwise_thread_work_size() * num_threads(); +} + +template +constexpr auto io_block_work_size() { + return num_threads() * elems_per_thread(); +} + +#ifdef USE_ROCM +template +constexpr auto input_size(args_t args, std::index_sequence) { + if constexpr (sizeof...(Is) == 0) { + return 0; + } else { + return sizeof(std::tuple_element_t<0, args_t>); + } +} + +template +constexpr auto calc_optimal_vec_size() { + static_assert(vec_size != 0); + static_assert(io_size != 0); + if constexpr (io_size == 1 && vec_size >= 16) { + return 16; + } else if constexpr (io_size <= 2 && vec_size >= 8) { + return 8; + } else if constexpr (io_size <= 4 && vec_size >= 4) { + return 4; + } else if constexpr (vec_size >= 4) { + return 4; + } else if constexpr (vec_size >= 2) { + return 2; + } else { + return 1; + } +} +#endif + +template +constexpr auto calc_io_size(){ + using traits = function_traits; + using args_t = typename traits::ArgsTuple; +#ifdef USE_ROCM + constexpr auto input_size = at::native::input_size(args_t{}, std::make_index_sequence>{}); + constexpr auto output_size = sizeof(typename traits::result_type); + return (input_size > 0) ? ((input_size < output_size) ? 
input_size : output_size) : output_size; +#else + constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence>{}); + constexpr auto output_size = sizeof(typename traits::result_type); + return input_size + output_size; +#endif +} + +#ifndef USE_ROCM +// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel +// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be +// used on sm_90 and sm_100 exclusively. +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + if constexpr (vec_size == 8) { +#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + int remaining = N - io_block_work_size() * blockIdx.x; + + if (remaining < io_block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + elems_per_thread()>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized()>(data)); + } +#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 + } else { + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + int remaining = N - io_block_work_size() * blockIdx.x; + + if (remaining < io_block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + elems_per_thread()>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized()>(data)); + } + } +} + +#else // USE_ROCM +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + using traits = function_traits; + constexpr auto io_size = calc_io_size(); +#ifdef __gfx942__ + constexpr int tws = (io_size >= 2) ? 
8 : 16; +#else + constexpr int tws = elems_per_thread(); +#endif + constexpr int bws = tws * num_threads(); + int remaining = N - bws * blockIdx.x; + + if (remaining < bws) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + tws>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + constexpr auto optimal_vec_size = calc_optimal_vec_size(); + elementwise_kernel_helper( + f, memory::policies::vectorized(data)); + } +} +#endif // USE_ROCM + +template < + typename func_t, + typename array_t, + int elems_per_thread, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + int remaining = N - elems_per_thread * num_threads() * blockIdx.x; + auto policy = memory::policies:: + unroll( + data, remaining, ic, oc, l, s); + elementwise_kernel_helper(f, policy); +} + +// this function assume trivial 1d and no dynamic casting +template +static inline void launch_vectorized_kernel( + int64_t N, + const func_t& f, + array_t data) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + auto stream = at::cuda::getCurrentCUDAStream(); +#ifdef USE_ROCM + int vec_size = memory::can_vectorize_up_to(data); + c10::DeviceIndex curDevice = -1; + AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice)); + int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread(); +#else + using cpp_type = typename function_traits::result_type; + const uint16_t max_vec_size = memory::can_vectorize_up_to(data); + uint16_t vec_size = 16 / static_cast(sizeof(cpp_type)); + vec_size = std::min(vec_size, max_vec_size); + // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC + // that causes some numerical mismatches with uint8 on sm80 and sm90. + // TODO: Revisit this after CUDA 12.8 update. 
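// Worked example (illustrative): for a functor returning float, sizeof(cpp_type) == 4,
// so the initial target is 16 / 4 == 4 lanes (a 16-byte access); for Half it would be
// 8 lanes, which the capability and element-size checks below may clamp back to 4.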
+ cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index()); + const int computeCapability = p->major * 10 + p->minor; + if (computeCapability != 90 && computeCapability != 100) { + vec_size = std::min(vec_size, 4); + } + if constexpr (sizeof(cpp_type) < 2) { + vec_size = std::min(vec_size, 4); + } + int tws = elems_per_thread(); +#endif + int bws = tws * num_threads(); + int64_t grid = (N + bws - 1) / bws; + switch (vec_size) { +#ifdef USE_ROCM + case 16: + vectorized_elementwise_kernel<16, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; +#endif + case 8: + vectorized_elementwise_kernel<8, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 4: + vectorized_elementwise_kernel<4, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_elementwise_kernel<2, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 1: { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + int64_t grid_unrolled = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size(); + unrolled_elementwise_kernel + <<>>( + N, f, data, input_calc, output_calc, loader, storer); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} + +#ifdef USE_ROCM +template < + int vec_size, + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + typename OutputType, + typename... InputTypes> +C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads()) +__global__ void vectorized_templated_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t inp_calc, + out_calc_t out_calc, + loader_t loader, + storer_t storer) { + int remaining = N - + vectorized_templated_config::block_work_size() * + (gridDim.x - blockIdx.x - 1); + constexpr bool reverted_idx = true; + + if (remaining < + vectorized_templated_config::block_work_size()) { // if this block handles + // the reminder, + // just do a naive unrolled loop + auto policy = memory::policies::unroll_base< + vectorized_templated_config::num_threads(), + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + vectorized_templated_config::elems_per_thread()>( + data, remaining, inp_calc, out_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + auto policy = memory::policies::vectorized_templated< + vec_size, + array_t, + vectorized_templated_config::elems_per_thread(), + vectorized_templated_config::num_threads(), + OutputType, + InputTypes...>(data); + elementwise_kernel_helper(f, policy); + } +} + +// This function assume trivial 1d and supports template specialization +// to avoid dynamic casting. +// Input vectorization size is based on runtime information, i.e. +// the actual data types of the input and output tensor and cannot +// be determined using the functor type, as in regular non-templated +// vectorized kernels. The caller is in charge of selecting the correct input +// vectorization length. +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + typename OutputType, + typename... 
InputTypes> +static inline void launch_vectorized_templated_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) / + vectorized_templated_config::block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + int vec_size = memory::can_vectorize_up_to(data); + switch (vec_size) { + case 8: + vectorized_templated_elementwise_kernel< + 8, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 4: + vectorized_templated_elementwise_kernel< + 4, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_templated_elementwise_kernel< + 2, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: + // vector size 1 is not handled as part of vectorize_templated kernel + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} +#endif + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +static inline void launch_unrolled_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + + int64_t grid = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel + <<>>(N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel(int N, func_t f) { + int tid = threadIdx.x; + int nv = nt * vt; + int idx = nv * blockIdx.x + tid; +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void launch_legacy_kernel(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#ifdef USE_ROCM +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel_manual_unroll(int N, func_t f) { + int tid = threadIdx.x; + constexpr int nv = nt * vt; + int idx = nv * blockIdx.x + tid; + if ((idx + nt*(vt-1)) < N) { + f(idx, true); + } else { +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx, false); + idx += nt; + } + } + } +} + +template +static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel_manual_unroll<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} +#endif + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* 
const C10_RESTRICT data[], + const index_t strides[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, i, Indices{}); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::fetch_and_cast::type>( + dtypes[I], data[I] + i * strides[I])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, dtypes, i, Indices{}); +} + +template +void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { + return launch_vectorized_kernel(numel, f, data); + } + auto offset_calc = ::make_offset_calculator(iter); +#ifndef USE_ROCM + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4; + launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data[1], &offsets[1], 1); + }); +#else + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 
4 : 8; + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + if constexpr (unroll_factor == 4) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx+grp_sz); + auto offsets2 = offset_calc.get(idx+grp_sz*2); + auto offsets3 = offset_calc.get(idx+grp_sz*3); + arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]); + arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]); + arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]); + arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]); + auto tmp0 = invoke(f, &data[1], &offsets0[1], 1); + auto tmp1 = invoke(f, &data[1], &offsets1[1], 1); + auto tmp2 = invoke(f, &data[1], &offsets2[1], 1); + auto tmp3 = invoke(f, &data[1], &offsets3[1], 1); + *out0 = tmp0; + *out1 = tmp1; + *out2 = tmp2; + *out3 = tmp3; + } else { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx+grp_sz); + auto offsets2 = offset_calc.get(idx+grp_sz*2); + auto offsets3 = offset_calc.get(idx+grp_sz*3); + auto offsets4 = offset_calc.get(idx+grp_sz*4); + auto offsets5 = offset_calc.get(idx+grp_sz*5); + auto offsets6 = offset_calc.get(idx+grp_sz*6); + auto offsets7 = offset_calc.get(idx+grp_sz*7); + arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]); + arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]); + arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]); + arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]); + arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]); + arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]); + arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]); + arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]); + auto tmp0 = invoke(f, &data[1], &offsets0[1], 1); + auto tmp1 = invoke(f, &data[1], &offsets1[1], 1); + auto tmp2 = invoke(f, &data[1], &offsets2[1], 1); + auto tmp3 = invoke(f, &data[1], &offsets3[1], 1); + auto tmp4 = invoke(f, &data[1], &offsets4[1], 1); + auto tmp5 = invoke(f, &data[1], &offsets5[1], 1); + auto tmp6 = invoke(f, &data[1], &offsets6[1], 1); + auto tmp7 = invoke(f, &data[1], &offsets7[1], 1); + *out0 = tmp0; + *out1 = tmp1; + *out2 = tmp2; + *out3 = tmp3; + *out4 = tmp4; + *out5 = tmp5; + *out6 = tmp6; + *out7 = tmp7; + } + } else { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data[1], &offsets[1], 1); + } + }); +#endif +} + +#ifdef USE_ROCM +namespace { +template < + typename TupleLike, + typename FirstParamTy, + typename SecondParamTy, + size_t arity, + size_t arg_num = 0> +struct check_binary_functor_types_for_specialization { + constexpr static inline bool check() { + if constexpr (arity != 2) + return false; + if constexpr (arg_num == 0) { + using SelectedType = std::tuple_element_t; + if constexpr (std::is_same_v) + return check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arg_num + 1>::check(); + } else if constexpr (arg_num == 1) { + using SelectedType2 = std::tuple_element_t; + if constexpr (std::is_same_v) + return check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arg_num + 1>::check(); + } + return false; + } +}; + +// Bottom case: if we got this far, assume correct type matching except +// when there are no arguments (arity == 0). 
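// For instance (illustrative): for a functor whose ArgsTuple is std::tuple<float, float>
// with arity == 2, the primary template checks argument 0 against FirstParamTy, then
// argument 1 against SecondParamTy, and finally instantiates the <..., arity, arity>
// bottom case, which reports success; any type mismatch, or arity != 2, yields false.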
+template < + typename TupleLike, + typename FirstParamTy, + typename SecondParamTy, + size_t arity> +struct check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arity> { + constexpr static inline bool check() { + if constexpr (arity != 0) + return true; + return false; + } +}; + +template +struct check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + 0, + 0> { + constexpr static inline bool check() { + return false; + } +}; + +// The following is a list of type specializations for vectorized_templated +// elementwise kernel. The three types refer to runtime types of the output +// tensor, first tensor argument, and the second tensor argument used for a +// binary functor. +constexpr std::array rt_binary_specializations = { + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value})}; + +bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) { + if (iter.ninputs() != 2) + return false; + for (auto spec : rt_binary_specializations) + if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] && + iter.input_dtype(1) == spec[2]) + return true; + return false; +} + +template +struct type_specialized_kernel_launcher { + template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> + static void apply( + ScalarType ret_t, + ScalarType arg0_t, + ScalarType arg1_t, + int64_t numel, + func_t f, + array_t data, + inp_calc_t input_offset_calculator, + out_calc_t output_offset_calculator, + loader_t loader, + storer_t storer) { + if (ret_t == rt_binary_specializations[arg_index][0] && + arg0_t == rt_binary_specializations[arg_index][1] && + arg1_t == rt_binary_specializations[arg_index][2]) + launch_vectorized_templated_kernel< + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + decltype(c10::impl::ScalarTypeToCPPType< + rt_binary_specializations[arg_index][0]>::t), + decltype(c10::impl::ScalarTypeToCPPType< + rt_binary_specializations[arg_index][1]>::t), + decltype(c10::impl::ScalarTypeToCPPType< + rt_binary_specializations[arg_index][2]>::t)>( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); + } +}; + +} // namespace +#endif + +template +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { + if (!needs_dynamic_casting::check(iter)) { + return gpu_kernel_impl_nocast(iter, f); + } + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + 
int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { +#ifdef USE_ROCM + // Attempt to call specialized vectorized elementwise kernel + // that enables interleaving. + if (check_binary_rt_types_for_specialization(iter) && + memory::can_vectorize_up_to(data) > 1) { + // constexpr to reduce the amount of kernels generated for + // vectorized templated elementwise and limit which functors are actually + // applied to the load and store at compile time. + using func_tuple = typename traits::ArgsTuple; + if constexpr ( + std::is_same_v && traits::arity == 2 && + check_binary_functor_types_for_specialization< + func_tuple, + float, + float, + traits::arity, + /*arg_num=*/0>::check()) { + // If we got here, we know we are in one of the specialized cases. We + // need to translate the runtime type to a statically known type. This + // is effectively hoisting to the host the switch over runtime type in + // the kernel in fetch_and_cast. Loader, storer, offset calculators are + // only needed for the reminder loop. + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + memory::detail::static_unroll< + type_specialized_kernel_launcher, + rt_binary_specializations.size()>:: + with_args( + iter.dtype(0), + iter.input_dtype(0), + iter.input_dtype(1), + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); + return; + } + } + std::array dtypes; + auto inner_strides = iter.get_inner_strides(); + std::array strides; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + strides[i] = inner_strides[i]; + } + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + void* out0 = data[0] + strides[0] * idx; + void* out1 = data[0] + strides[0] * (idx + grp_sz); + void* out2 = data[0] + strides[0] * (idx + grp_sz * 2); + void* out3 = data[0] + strides[0] * (idx + grp_sz * 3); + arg0_t result0 = invoke(f, &data[1], &strides[1], &dtypes[1], idx); + arg0_t result1 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz)); + arg0_t result2 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz * 2)); + arg0_t result3 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz * 3)); + c10::cast_and_store(dtypes[0], out0, result0); + c10::cast_and_store(dtypes[0], out1, result1); + c10::cast_and_store(dtypes[0], out2, result2); + c10::cast_and_store(dtypes[0], out3, result3); + } else { + void* out = data[0] + strides[0] * idx; + arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx); + c10::cast_and_store(dtypes[0], out, result); + } + }); +#else + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + launch_unrolled_kernel( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); +#endif + } else { + std::array dtypes; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + } + auto offset_calc = ::make_offset_calculator(iter); + launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); + 
c10::cast_and_store(dtypes[0], out, result); + }); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..f0dc24872e6157de677146db592fe0fed86d51b9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +struct TupleInfoCPU { + template + using tuple = thrust::tuple; + + template + static constexpr auto tie(Types&... args) noexcept { + return thrust::tie(args...); + } +}; + +template +using CompositeRandomAccessorCPU = + CompositeRandomAccessor; + +template +void swap( + references_holder rh1, + references_holder rh2 +) { + return thrust::swap(rh1.data(), rh2.data()); +} + +template +auto get(references_holder rh) -> decltype(thrust::get(rh.data())) { + return thrust::get(rh.data()); +} + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Copy.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..285b888db77863d664caf710095d9dca70eb50a6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Copy.h @@ -0,0 +1,11 @@ +#pragma once + +namespace at { +struct TensorIteratorBase; + +namespace native { + +void direct_copy_kernel_cuda(TensorIteratorBase& iter); + +} +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h new file mode 100644 index 0000000000000000000000000000000000000000..3c5c40d021a59c51c1e0000243061946de378421 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h @@ -0,0 +1,494 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native::detail { + +// Enum representing the FFT type +enum class CuFFTTransformType : int8_t { + C2C, // Complex-to-complex + R2C, // Real-to-complex + C2R, // Complex-to-real +}; + +// This struct is used to let us easily compute hashes of the +// parameters. +// It will be the **key** to the plan cache. +struct CuFFTParams +{ + int64_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3 + // These include additional batch dimension as well. 
+ int64_t sizes_[max_rank + 1]; + int64_t input_strides_[max_rank + 1]; + int64_t output_strides_[max_rank + 1]; + CuFFTTransformType fft_type_; + ScalarType value_type_; + + CuFFTParams() = default; + + CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides, + IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) { + // Padding bits must be zeroed for hashing + memset(this, 0, sizeof(*this)); + signal_ndim_ = signal_sizes.size() - 1; + fft_type_ = fft_type; + value_type_ = value_type; + + TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size()); + TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size()); + TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank); + + std::copy(signal_sizes.cbegin(), signal_sizes.cend(), sizes_); + std::copy(in_strides.cbegin(), in_strides.cend(), input_strides_); + std::copy(out_strides.cbegin(), out_strides.cend(), output_strides_); + } +}; + +static_assert(std::is_trivial_v ); + +// Returns true if the transform type has complex input +inline bool cufft_complex_input(CuFFTTransformType type) { + switch (type) { + case CuFFTTransformType::C2C: + case CuFFTTransformType::C2R: + return true; + + case CuFFTTransformType::R2C: + return false; + } + TORCH_INTERNAL_ASSERT(false); +} + +// Returns true if the transform type has complex output +inline bool cufft_complex_output(CuFFTTransformType type) { + switch (type) { + case CuFFTTransformType::C2C: + case CuFFTTransformType::R2C: + return true; + + case CuFFTTransformType::C2R: + return false; + } + TORCH_INTERNAL_ASSERT(false); +} + +// Create transform type enum from bools representing if input and output are complex +inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) { + if (complex_input && complex_output) { + return CuFFTTransformType::C2C; + } else if (complex_input && !complex_output) { + return CuFFTTransformType::C2R; + } else if (!complex_input && complex_output) { + return CuFFTTransformType::R2C; + } + TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported"); +} + + +class CuFFTHandle { + ::cufftHandle handle_; +public: + + CuFFTHandle() { + CUFFT_CHECK(cufftCreate(&handle_)); + } + + ::cufftHandle & get() { return handle_; } + const ::cufftHandle & get() const { return handle_; } + + ~CuFFTHandle() { +// Not using fftDestroy() for rocFFT to work around double freeing of handles +#if !defined(USE_ROCM) + cufftDestroy(handle_); +#endif + } +}; + +__forceinline__ +static bool is_pow_of_two(int64_t x) { + return (x & (x - 1)) == 0; +} + +using cufft_size_type = long long int; + +using CuFFTDimVector = c10::SmallVector; + +// Struct representing a tensor in CuFFT's data layout for planning transforms +// See NOTE [ cuFFT Embedded Strides ]. +struct CuFFTDataLayout { + CuFFTDimVector embed; + cufft_size_type stride, dist; + bool must_clone, simple; +}; + +// Returns a cufft embedding for a contiguous signal of the given size. +// e.g. if the input is cloned, this will be the resulting data layout +// See NOTE [ cuFFT Embedded Strides ]. +inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) { + CuFFTDataLayout layout; + layout.simple = true; + layout.must_clone = false; + layout.embed.assign(sizes.cbegin() + 1, sizes.cend()); + if (onesided) { + layout.embed.back() = sizes.back() / 2 + 1; + } + layout.stride = 1; + layout.dist = 1; + for (const auto& len : layout.embed) { + layout.dist *= len; + } + return layout; +} + +// Convert strides to a CuFFT embedded representation. 
+// If strides cannot be embedded, returns a simple layout and sets must_clone flag +// See NOTE [ cuFFT Embedded Strides ]. +inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) { + const auto signal_ndim = strides.size() - 1; + CuFFTDataLayout layout; + auto last_stride = strides[signal_ndim]; + layout.must_clone = (last_stride <= 0); + + const auto last_dim_size = onesided ? + sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim]; + const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size; + + // Zero stides are not allowed, even if the batch size is one. + // If that happens just set a dummy case + if (sizes[0] == 1) { + layout.dist = signal_numel; + } else if (strides[0] == 0) { + layout.must_clone = true; + } else { + layout.dist = strides[0]; + } + + // Calculate the embedding shape, or set must_clone if the strides cannot be embedded + layout.embed.resize(signal_ndim); + for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) { + auto stride = strides[i]; + if (sizes[i] == 1) { + layout.embed[i] = 1; + } else if (stride > 0 && stride % last_stride == 0) { + layout.embed[i] = stride / last_stride; + last_stride = stride; + } else { + layout.must_clone = true; + } + } + + if (layout.must_clone) { + // If the input needs to be cloned, assume it will be contiguous + layout = cufft_simple_embed(sizes, onesided); + layout.must_clone = true; + } else { + layout.embed[0] = sizes[1]; + layout.stride = strides[signal_ndim]; + // Determine if layout represents a simple embedding (contiguous data) + layout.simple = [&] { + for (const auto i : c10::irange(1, signal_ndim - 1)) { + if (layout.embed[i] != sizes[i + 1]) { + return false; + } + } + + return (layout.stride == 1 && layout.dist == signal_numel && + layout.embed.back() == last_dim_size); + }(); + } + return layout; +} + +// This class contains all the information needed to execute a cuFFT plan: +// 1. the plan +// 2. whether to clone input before executing the plan +// 3. the workspace size needed +// +// This class will be the **value** in the plan cache. +// It **owns** the raw plan via a unique_ptr. +class CuFFTConfig { +public: + + // Only move semantics is enought for this class. Although we already use + // unique_ptr for the plan, still remove copy constructor and assignment op so + // we don't accidentally copy and take perf hit. 
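Before the move-only CuFFTConfig below, a small worked example of the layout computation above may be useful. For a contiguous real-to-complex signal the embedded layout collapses to embed = signal sizes (with the last dimension halved plus one), stride = 1, and dist = product of the embedded sizes, per NOTE [ cuFFT Embedded Strides ]. The sketch below is standalone and simplified; it mirrors only the contiguous cufft_simple_embed path, not the strided as_cufft_embed logic.

// --- illustrative sketch (standalone, simplified; not part of the vendored header) ---
#include <cstdint>
#include <iostream>
#include <vector>

struct Layout {
  std::vector<int64_t> embed;  // per-dimension embedded sizes (excluding batch)
  int64_t stride;              // distance between consecutive elements of one signal
  int64_t dist;                // distance between consecutive signals in the batch
};

// Mirror of the contiguous case: sizes = {batch, n1, ..., nk}.
Layout simple_embed(const std::vector<int64_t>& sizes, bool onesided) {
  Layout l;
  l.embed.assign(sizes.begin() + 1, sizes.end());
  if (onesided) {
    l.embed.back() = l.embed.back() / 2 + 1;  // R2C stores n/2 + 1 complex outputs
  }
  l.stride = 1;
  l.dist = 1;
  for (auto len : l.embed) {
    l.dist *= len;
  }
  return l;
}

int main() {
  // batch = 2, 8x16 real signal, one-sided (R2C) output
  auto l = simple_embed({2, 8, 16}, /*onesided=*/true);
  std::cout << l.embed[0] << "x" << l.embed[1]   // 8x9
            << " stride=" << l.stride            // 1
            << " dist=" << l.dist << "\n";       // 72
}
// --- end sketch ---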
+ CuFFTConfig(const CuFFTConfig&) = delete; + CuFFTConfig& operator=(CuFFTConfig const&) = delete; + + explicit CuFFTConfig(const CuFFTParams& params): + CuFFTConfig( + IntArrayRef(params.input_strides_, params.signal_ndim_ + 1), + IntArrayRef(params.output_strides_, params.signal_ndim_ + 1), + IntArrayRef(params.sizes_, params.signal_ndim_ + 1), + params.fft_type_, + params.value_type_) {} + + // For complex types, strides are in units of 2 * element_size(dtype) + // sizes are for the full signal, including batch size and always two-sided + CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides, + IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype): + fft_type_(fft_type), value_type_(dtype) { + + // signal sizes (excluding batch dim) + CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end()); + + // input batch size + const int64_t batch = sizes[0]; + const int64_t signal_ndim = sizes.size() - 1; + + // Since cuFFT has limited non-unit stride support and various constraints, we + // use a flag to keep track throughout this function to see if we need to + // input = input.clone(); + +#if defined(USE_ROCM) + // clone input to avoid issues with hipfft clobering the input and failing tests + clone_input = true; +#else + clone_input = false; +#endif + + // For half, base strides on the real part of real-to-complex and + // complex-to-real transforms are not supported. Since our output is always + // contiguous, only need to check real-to-complex case. + if (dtype == ScalarType::Half) { + // cuFFT on half requires compute capability of at least SM_53 + auto dev_prop = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3), + "cuFFT doesn't support signals of half type with compute " + "capability less than SM_53, but the device containing input half " + "tensor only has SM_", dev_prop->major, dev_prop->minor); + for (const auto i : c10::irange(signal_ndim)) { + TORCH_CHECK(is_pow_of_two(sizes[i + 1]), + "cuFFT only supports dimensions whose sizes are powers of two when" + " computing in half precision, but got a signal size of", + sizes.slice(1)); + } + clone_input |= in_strides.back() != 1; + } + + CuFFTDataLayout in_layout; + if (clone_input) { + in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R); + } else { + in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R); + } + auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C); + TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding"); + clone_input |= in_layout.must_clone; + + // Check if we can take advantage of simple data layout. + // + // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. + + const bool simple_layout = in_layout.simple && out_layout.simple; + cudaDataType itype, otype, exec_type; + const auto complex_input = cufft_complex_input(fft_type); + const auto complex_output = cufft_complex_output(fft_type); + if (dtype == ScalarType::Float) { + itype = complex_input ? CUDA_C_32F : CUDA_R_32F; + otype = complex_output ? CUDA_C_32F : CUDA_R_32F; + exec_type = CUDA_C_32F; + } else if (dtype == ScalarType::Double) { + itype = complex_input ? CUDA_C_64F : CUDA_R_64F; + otype = complex_output ? CUDA_C_64F : CUDA_R_64F; + exec_type = CUDA_C_64F; + } else if (dtype == ScalarType::Half) { + itype = complex_input ? CUDA_C_16F : CUDA_R_16F; + otype = complex_output ? 
CUDA_C_16F : CUDA_R_16F; + exec_type = CUDA_C_16F; + } else { + TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype); + } + + // disable auto allocation of workspace to use THC allocator + CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0)); + + size_t ws_size_t; + + // make plan + if (simple_layout) { + // If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL. + // In such case, cuFFT ignores istride, ostride, idist, and odist + // by assuming istride = ostride = 1. + // + // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. + CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, + batch, &ws_size_t, exec_type)); + } else { + CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(), + in_layout.embed.data(), in_layout.stride, in_layout.dist, itype, + out_layout.embed.data(), out_layout.stride, out_layout.dist, otype, + batch, &ws_size_t, exec_type)); + } + ws_size = static_cast(ws_size_t); + } + + const cufftHandle &plan() const { return plan_ptr.get(); } + + CuFFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } + bool should_clone_input() const { return clone_input; } + int64_t workspace_size() const { return ws_size; } + +private: + CuFFTHandle plan_ptr; + bool clone_input; + int64_t ws_size; + CuFFTTransformType fft_type_; + ScalarType value_type_; +}; + +#if defined(USE_ROCM) + // Note that the max plan number for CUDA version < 10 has to be 1023 + // due to a bug that fails on the 1024th plan + constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023; + constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM; +#else + constexpr int64_t CUFFT_MAX_PLAN_NUM = std::numeric_limits::max(); + // The default max cache size chosen for CUDA version > 10 is arbitrary. + // This number puts a limit on how big of a plan cache should we maintain by + // default. Users can always configure it via cufft_set_plan_cache_max_size. + constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = 4096; +#endif +static_assert(0 <= CUFFT_MAX_PLAN_NUM && CUFFT_MAX_PLAN_NUM <= std::numeric_limits::max(), + "CUFFT_MAX_PLAN_NUM not in size_t range"); +static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM, + "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range"); + +// This cache assumes that the mapping from key to value never changes. +// This is **NOT** thread-safe. Please use a mutex when using it **AND** the +// value returned from try_emplace_value. +// The contract of using this cache is that try_emplace_value should only be +// used when the max_size is positive. 
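Before the class itself, a minimal sketch of the list-plus-hash-map LRU pattern it uses may make the lookup path easier to follow: a std::list of key/value pairs kept in most-recently-used order, plus an unordered_map from key to list iterator for O(1) lookup. All names below are illustrative; the real cache additionally bounds its size by CUFFT_MAX_PLAN_NUM and exposes a mutex that callers must hold.

// --- illustrative sketch (assumed names, standalone; not part of the vendored header) ---
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

class SimpleLRU {
 public:
  explicit SimpleLRU(size_t max_size) : max_size_(max_size) {}

  int& lookup(const std::string& key) {
    auto it = map_.find(key);
    if (it != map_.end()) {
      // Hit: move the entry to the front of the usage list (splice keeps iterators valid).
      usage_.splice(usage_.begin(), usage_, it->second);
      return it->second->second;
    }
    // Miss: evict the least recently used entry if the cache is full.
    if (usage_.size() >= max_size_) {
      map_.erase(usage_.back().first);
      usage_.pop_back();
    }
    usage_.emplace_front(key, 0);  // "construct the plan" (here: just an int)
    map_.emplace(key, usage_.begin());
    return usage_.front().second;
  }

 private:
  std::list<std::pair<std::string, int>> usage_;  // MRU at front, LRU at back
  std::unordered_map<std::string, std::list<std::pair<std::string, int>>::iterator> map_;
  size_t max_size_;
};

int main() {
  SimpleLRU cache(2);
  cache.lookup("a") = 1;
  cache.lookup("b") = 2;
  cache.lookup("a");                         // touch "a" so "b" becomes LRU
  cache.lookup("c") = 3;                     // evicts "b"
  std::cout << cache.lookup("a") << "\n";    // 1 (still cached)
}
// --- end sketch ---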
+class CuFFTParamsLRUCache { +public: + using kv_t = typename std::pair; + using map_t = typename std::unordered_map, + typename std::list::iterator, + ParamsHash, + ParamsEqual>; + using map_kkv_iter_t = typename map_t::iterator; + + + CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {} + + CuFFTParamsLRUCache(int64_t max_size) { + _set_max_size(max_size); + } + + CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept : + _usage_list(std::move(other._usage_list)), + _cache_map(std::move(other._cache_map)), + _max_size(other._max_size) {} + + CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept { + _usage_list = std::move(other._usage_list); + _cache_map = std::move(other._cache_map); + _max_size = other._max_size; + return *this; + } + + // If key is in this cache, return the cached config. Otherwise, emplace the + // config in this cache and return it. + // Return const reference because CuFFTConfig shouldn't be tampered with once + // created. + const CuFFTConfig &lookup(CuFFTParams params) { + AT_ASSERT(_max_size > 0); + + map_kkv_iter_t map_it = _cache_map.find(params); + // Hit, put to list front + if (map_it != _cache_map.end()) { + _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second); + return map_it->second->second; + } + + // Miss + // remove if needed + if (_usage_list.size() >= _max_size) { + auto last = _usage_list.end(); + last--; + _cache_map.erase(last->first); + _usage_list.pop_back(); + } + + // construct new plan at list front, then insert into _cache_map + _usage_list.emplace_front(std::piecewise_construct, + std::forward_as_tuple(params), + std::forward_as_tuple(params)); + auto kv_it = _usage_list.begin(); + _cache_map.emplace(std::piecewise_construct, + std::forward_as_tuple(kv_it->first), + std::forward_as_tuple(kv_it)); + return kv_it->second; + } + + void clear() { + _cache_map.clear(); + _usage_list.clear(); + } + + void resize(int64_t new_size) { + _set_max_size(new_size); + auto cur_size = _usage_list.size(); + if (cur_size > _max_size) { + auto delete_it = _usage_list.end(); + for (size_t i = 0; i < cur_size - _max_size; i++) { + delete_it--; + _cache_map.erase(delete_it->first); + } + _usage_list.erase(delete_it, _usage_list.end()); + } + } + + size_t size() const { return _cache_map.size(); } + + size_t max_size() const noexcept { return _max_size; } + + std::mutex mutex; + +private: + // Only sets size and does value check. Does not resize the data structures. + void _set_max_size(int64_t new_size) { + // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since + // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check + // first. + TORCH_CHECK(new_size >= 0, + "cuFFT plan cache size must be non-negative, but got ", new_size); + TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM, + "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size); + _max_size = static_cast(new_size); + } + + std::list _usage_list; + map_t _cache_map; + size_t _max_size; +}; + +// Since ATen is separated into CPU build and CUDA build, we need a way to call +// these functions only when CUDA is loaded. We use CUDA hooks for this purpose +// (at cuda/detail/CUDAHooks.cpp), and call the hooked functions from the actual +// native function counterparts (at native/SpectralOps.cpp), i.e., +// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size +// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache. 
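The declarations that follow are the implementations reached through those CUDA hooks. As a rough standalone sketch of the hooks indirection described above (an interface with a throwing default, and a real implementation installed only when the CUDA backend is loaded), with every name below an illustrative stand-in rather than the actual at::cuda::detail machinery:

// --- illustrative sketch of the hooks indirection (assumed names, standalone) ---
#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>

struct FFTHooksInterface {
  virtual ~FFTHooksInterface() = default;
  // Default implementation: CPU-only build, no CUDA backend available.
  virtual int64_t cufft_get_plan_cache_max_size(int /*device_index*/) const {
    throw std::runtime_error("cuFFT plan cache requested but the CUDA backend is not loaded");
  }
};

struct LoadedCUDAHooks : FFTHooksInterface {
  int64_t cufft_get_plan_cache_max_size(int device_index) const override {
    // In the real code this would forward to cufft_get_plan_cache_max_size_impl(device_index).
    (void)device_index;
    return 4096;
  }
};

// The real registry lives inside ATen; here a global pointer stands in for it.
std::unique_ptr<FFTHooksInterface> hooks = std::make_unique<FFTHooksInterface>();

int main() {
  hooks = std::make_unique<LoadedCUDAHooks>();                   // "CUDA library loaded"
  std::cout << hooks->cufft_get_plan_cache_max_size(0) << "\n";  // 4096
}
// --- end sketch ---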
+int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index); +void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size); +int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index); +void cufft_clear_plan_cache_impl(DeviceIndex device_index); + +} // namespace at::native::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..45e523d04b84aa26ad03d604d226e662e0f6982c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace at { namespace native { + +// This means that max dim is 3 + 2 = 5 with batch dimension and possible +// complex dimension +constexpr int max_rank = 3; + +static inline std::string _cudaGetErrorEnum(cufftResult error) +{ + switch (error) + { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; +#if !defined(USE_ROCM) + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; +#endif + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + default: + std::ostringstream ss; + ss << "unknown error " << error; + return ss.str(); + } +} + +static inline void CUFFT_CHECK(cufftResult error) +{ + if (error != CUFFT_SUCCESS) { + std::ostringstream ss; + ss << "cuFFT error: " << _cudaGetErrorEnum(error); + TORCH_CHECK(false, ss.str()); + } +} + +}} // at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6dad8efbca88aa7144a6ef08646940402aea3521 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh @@ -0,0 +1,25 @@ +#pragma once + +namespace at::native { +#if defined(USE_ROCM) +// take these out when ROCm implements std:: math functions +#include +template +static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); + +template <> +__forceinline__ __device__ float device_sqrt(float val) { + return ::sqrtf(val); +} + +template <> +__forceinline__ __device__ double device_sqrt(double val) { + return ::sqrt(val); +} +#else +template +__forceinline__ __device__ double device_sqrt(scalar_t val) { + return std::sqrt(val); +} +#endif +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h new file mode 100644 index 
0000000000000000000000000000000000000000..729a9f12c712435a5c2cbba6a9ba3d4e126beffe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h @@ -0,0 +1,697 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +// launch bounds used for kernels utilizing TensorIterator +const uint32_t block_size_bound = 256; +const uint32_t grid_size_bound = 4; +// At the time of writing, there is no curand_* call that increments the offset by more than 4. +// See: https://docs.nvidia.com/cuda/archive/11.8.0/curand/group__DEVICE.html +const uint32_t max_generator_offsets_per_curand_call = 4; + +// utility function that calculates proper philox_offset +// for distributions utilizing TensorIterator. For distributions using +// TensorIterator, we are using a grid-stride loop with each +// thread yielding one element per thread. For the edge of the grid-stride +// loop, if the tensor size is large, the unroll loop will kick in and the float4 +// from curand4 will start getting utilized (for common tensor sizes, we end up +// using rand.x from each thread). The philox_offset calculation was changed to +// (number of elements per thread * maximum generator increment per "curand_*" call), which makes +// sure that philox offset increment is not less than the number of randoms used +// in each thread. +std::tuple calc_execution_policy(const int64_t total_elements, const uint32_t unroll_factor) { + const uint64_t numel = static_cast(total_elements); + const uint32_t block_size = block_size_bound; + dim3 dim_block(block_size); + dim3 grid((numel + block_size - 1) / block_size); + uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min( + static_cast(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm, + grid.x); + //number of times random will be generated per thread, to offset philox counter in thc random state + uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll_factor) + 1) * max_generator_offsets_per_curand_call; + return std::make_tuple(counter_offset, grid, dim_block); +} + +// grid stride loop kernel for distributions +template +C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) +__global__ void distribution_elementwise_grid_stride_kernel(int64_t numel, + PhiloxCudaState philox_args, + const dist_t dist_func, + const transform_t transform_func) { + auto [seed, offset] = at::cuda::philox::unpack(philox_args); + int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed, idx, offset, &state); + + int64_t rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) * + blockDim.x * gridDim.x * unroll_factor; + for(int64_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { + auto rand = dist_func(&state); + #pragma unroll + for (int ii = 0; ii < unroll_factor; ii++) { + int64_t li = linear_index + blockDim.x * gridDim.x * ii; + if (li < numel) { + transform_func(li, static_cast((&rand.x)[ii])); + } + } + __syncthreads(); + } +} + +/** + * distribution_nullary_kernel is analogous to gpu_kernel in + * ATen/native/cuda/Loops.cuh. 
Like gpu_kernel, it uses + * TensorIterator to launch a kernel. However, the differences are + * - it launches a grid-stride loop based kernel. The kernel is not + * generic like elementwise_kernel in Loops.cuh and is specialized + * for the distribution kernels here. + * - For big size tensors, we can launch multiple kernels recursively + * (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox + * offset calculation is done in this function. + * + * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh + * to have grid-stride loop kernel and then use that to launch our distribution + * kernels? Note that we need a grid-stride loop kernel because, we found by testing + * that it achieves peak effective bandwidth. + */ +template +void distribution_nullary_kernel(at::TensorIteratorBase& iter, + RNG gen, + const dist_t& dist_func, + const transform_t transform_func) { + const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t); + TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1."); + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + auto [counter_offset, grid, block] = calc_execution_policy(numel, unroll_factor); + PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_nullary_kernel(sub_iter, + gen, dist_func, transform_func); + } + return; + } + + char* out_data = (char*)iter.data_ptr(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + if (iter.is_trivial_1d()) { + auto strides = iter.get_inner_strides(); + int stride0 = strides[0]; + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + scalar_t* out = (scalar_t*)&out_data[stride0 * idx]; + *out = transform_func(rand); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + auto offset_calc = make_offset_calculator<1>(iter); + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + auto offsets = offset_calc.get(idx); + scalar_t* out = (scalar_t*)&out_data[offsets[0]]; + *out = transform_func(rand); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +// Binary kernel +template +__global__ void distribution_binary_elementwise_kernel( + int numel, + func_t f, + PhiloxCudaState philox_args, + typename function_traits::result_type *output_data, + const typename function_traits::template arg<1>::type *input_data_1, + const typename function_traits::template arg<2>::type *input_data_2, + inp_offset_calc_t inp_calc, + out_offset_calc_t out_calc) { + auto seeds = at::cuda::philox::unpack(philox_args); + + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + + input_t_1 inputs_1[thread_work_size()]; + input_t_2 inputs_2[thread_work_size()]; + + int base_index = block_work_size() * blockIdx.x; + int remaining = std::min(numel - base_index, block_work_size()); + + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // load data into registers + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + 
break; + } + int input_idx = thread_idx + base_index; + auto offsets = inp_calc.get(input_idx); + inputs_1[i] = input_data_1[offsets[0]]; + inputs_2[i] = input_data_2[offsets[1]]; + + thread_idx += num_threads(); + } + + // compute and store + thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = out_calc.get(input_idx); + output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]); + thread_idx += num_threads(); + } +} + +template +void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) { + static_assert(std::is_same_v::template arg<0>::type, curandStatePhilox4_32_10_t&>, "the first argument of functor must be curandStatePhilox4_32_10_t"); + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + using output_t = typename function_traits::result_type; + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_binary_kernel(sub_iter, philox_args, f); + } + return; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing()); + + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + output_t *output_data = static_cast(iter.data_ptr(0)); + const input_t_1 *input_data_1 = static_cast(iter.data_ptr(1)); + const input_t_2 *input_data_2 = static_cast(iter.data_ptr(2)); + + int64_t grid = (numel + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + + if (iter.is_contiguous()) { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace +}} // namespace at::native + + +namespace at { +namespace native { +namespace templates { +namespace cuda { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) { +#ifdef FBCODE_CAFFE2 + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { + if (( + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) && range >= 1ULL << 32) + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, 
kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +#else + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { + if (range >= 1ULL << 28) // allow approx 5% skew in uniform int generation using % + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +#endif +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] { + if (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int_full_range(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] { + if (std::is_same_v || std::is_same_v) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [] __device__ (uint32_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, RNG gen) { + random_kernel(iter, gen); + } 
+}; + +// ==================================================================================================================== + +template +void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same_v) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_uniform2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_uniform4(state); }, + transform); + } +} + +template +void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same_v) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_normal2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_normal4(state); }, + transform); + } +} + +// ==================================================== Normal ======================================================== + +template +void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) { + auto iter = TensorIterator::borrowing_nullary_op(self); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda to multiply std and add mean + auto normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::normal(rand, mean, std)); + }; + normal_and_transform(iter, gen, normal_func); + }); +} + +template +struct NormalKernel { + void operator()(const TensorBase &self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================== + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] { + auto from = static_cast(from_); + auto to = static_cast(to_); + using opmath_t = at::opmath_type; + auto range = static_cast(to-from); + // define lambda to reverse bounds, multiply 'range' and add 'from_' + auto uniform_func = [range, from, to] __device__ (opmath_t rand) { + // Compute output value before reversing the bounds + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947 + auto value = static_cast(rand * range + from); + // reverse the bounds of curand4 from (0, 1] to [0, 1) + // Note that this method is from legacy THCTensorRandom and is likely to give + // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and + // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s. + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706 + auto reverse_bound_value = value == to ? 
from : value; + return reverse_bound_value; + }; + uniform_and_transform(iter, gen, uniform_func); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda for log_normal transformation + auto log_normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::log_normal(transformation::normal(rand, mean, std))); + }; + normal_and_transform(iter, gen, log_normal_func); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for geometric transformation + auto geometric_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::geometric(rand, p)); + }; + uniform_and_transform(iter, gen, geometric_func); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// ================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. 
dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] { + using accscalar_t = at::acc_type; + auto lambda = static_cast(lambda_); + // define lambda for exponential transformation + auto exponential_func = [lambda] __device__ (accscalar_t rand) { + return static_cast(transformation::exponential(rand, lambda)); + }; + uniform_and_transform(iter, gen, exponential_func); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] { + using accscalar_t = at::acc_type; + auto median = static_cast(median_); + auto sigma = static_cast(sigma_); + // define lambda for cauchy transformation + auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) { + return static_cast(transformation::cauchy(rand, median, sigma)); + }; + uniform_and_transform(iter, gen, cauchy_func); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ==================================================== Bernoulli ===================================================== + +template +void bernoulli_tensor_cuda_kernel( + const TensorBase &ret, const at::TensorBase &p, + PhiloxCudaState philox_args) { + auto functor = [philox_args] __device__( + int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4, + const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) { + auto seeds = at::cuda::philox::unpack(philox_args); + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // See Note [Register spilling in curand call for CUDA < 10] + float4 rand = curand_uniform4(&state); + switch (n) { + case 4: { + CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1); + v4 = static_cast(rand.w <= p4); + [[fallthrough]]; + } + case 3: { + CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1); + v3 = static_cast(rand.z <= p3); + [[fallthrough]]; + } + case 2: { + CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1); + v2 = static_cast(rand.y <= p2); + [[fallthrough]]; + } + case 1: { + CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1); + v1 = static_cast(rand.x <= p1); + } + } + }; + // The template argument `4` below indicates that we want to operate on four + // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. 
+ at::cuda::CUDA_tensor_apply2(ret, p, functor); +} + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) { + PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(10); + } + TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type()); + // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else + const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat; + auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type)); + auto p = expand_inplace(self, p_cuda); + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] { + if (std::is_same_v) { + return bernoulli_tensor_cuda_kernel(self, *p, rng_engine_inputs); + } else { + return bernoulli_tensor_cuda_kernel(self, *p, rng_engine_inputs); + } + }); +} + +template +void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for bernoulli transformation + auto bernoulli_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::bernoulli(rand, p)); + }; + uniform_and_transform(iter, gen, bernoulli_func); + }); +} + +template +struct BernoulliKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + bernoulli_kernel(iter, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}}}} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Distributions.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Distributions.h new file mode 100644 index 0000000000000000000000000000000000000000..053eff0c7d7a5a84db1601bf17fd19dc2cc35382 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Distributions.h @@ -0,0 +1,25 @@ +#pragma once + +namespace at { +struct CUDAGeneratorImpl; +struct TensorIteratorBase; +class TensorBase; + +namespace native { + +void launch_poisson_cuda_kernel( + const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen); + +void launch_gamma_kernel( + const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen); + +void launch_binomial_cuda_kernel( + TensorIteratorBase &iter, CUDAGeneratorImpl *gen); + +void launch_dirichlet_kernel(TensorIteratorBase &iter); + +void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter); + +void launch_dirichlet_grad_kernel(TensorIteratorBase &iter); + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6fce8d6eb6d2ab28547ee321bc74bb71a774687f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh @@ -0,0 +1,21 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + +Tensor embedding_backward_cuda_kernel( + const Tensor &grad, + const 
Tensor &orig_indices, + const Tensor &sorted_indices, + const Tensor &count, + int64_t num_weights, + int padding_idx = -1, + bool mode_mean = false, + const Tensor &offset2bag = Tensor(), + const Tensor &bag_size = Tensor(), + const Tensor &per_sample_weights = Tensor()); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6888dde20982fbfb0162e1143fcb7b8906cbb15f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh @@ -0,0 +1,738 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + +namespace { + +// TODO(crcrpar): Handle version bump in codegen. +// rel: +// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482 +inline void increment_version(TensorList tensors) { + for (const auto& t : tensors) { + t.unsafeGetTensorImpl()->bump_version(); + } +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListScalarListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ bool init_args( + T** args, + FusedOptimizerTensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ void load_args( + T r_args[][kILP], + T** args, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + for (int r_index = 0; r_index < depth; r_index++) { + r_args[r_index][ii] = 0; + if (i < n && i < chunk_size) { + r_args[r_index][ii] = args[r_index][i]; + } + } + } +} + +template +__device__ void store_args( + T* dst, + T* src, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const int64_t i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) + dst[i] = src[ii]; + } +} + +template +__device__ __forceinline__ void binary_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = 
threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args + // has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(scalar))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} + +template +__device__ __forceinline__ void pointwise_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args + // has depth 3 + load_args<3>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} + +// +// Binary Functors +// +template +struct BinaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct BinaryOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, 
all_aligned, op); + } +}; + +template +struct BinaryOpListAlphaFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct BinaryOpScalarTensorFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + T* scalar, + opmath_t alpha) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), + // r_args has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +// +// Unary Functors +// + +template +struct ZeroFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata<1>& tl) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const auto all_aligned = + 
init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + // store + load_store(args[0], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct UnaryOpFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +// +// Pointwise Functors +// + +template +struct PointwiseOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpListFunctor 
{ + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[depth - 1][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + // store + load_store(args[2], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + store_args(args[2], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + static_assert(depth == 3 || depth == 4, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + static_assert(depth == 2 || depth == 3, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = 
tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + static_assert(depth == 2 || depth == 3, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + const opmath_t scalar = tl.scalar_vals[tensor_loc]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + scalar); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + scalar); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(a, b); + } +}; + +template +struct reverse_power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(b, a); + } +}; + +} // namespace +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh new file mode 100644 index 
0000000000000000000000000000000000000000..32421ef305a9905cd6d54805429fa58bc78b0825 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace at::native { + +// std:: does not have clamp functors +template +struct minimum { + __device__ T operator()(const T& a, const T& b) const { + return (_isnan(a) || a < b) ? a : b; + } +}; + +template +struct maximum { + __device__ T operator()(const T& a, const T& b) const { + return (_isnan(a) || a > b) ? a : b; + } +}; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5041d03ed424d45a4e0d0be84b6f346e076ac701 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh @@ -0,0 +1,321 @@ +#pragma once +#include +#include + +namespace at::native { + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +__forceinline__ __device__ +scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } +} + +// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size, + bool align_corners, scalar_t *grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template +__forceinline__ __device__ +scalar_t clip_coordinates(scalar_t in, int clip_limit) { + return ::min(static_cast(clip_limit - 1), ::max(in, static_cast(0))); +} + +// clip_coordinates_set_grad works similarly to clip_coordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) { + // Note that it is important for the gradient calculation that borders + // are considered out of bounds. 
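+  //
+  // Worked example (illustrative values): with clip_limit = 5 the valid
+  // range is [0, 4], so
+  //   in = -0.3 -> returns 0.0 and *grad_in = 0  (clamped at the low border)
+  //   in =  2.4 -> returns 2.4 and *grad_in = 1  (identity inside the range)
+  //   in =  7.0 -> returns 4.0 and *grad_in = 0  (clamped at the high border)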
+ if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + scalar_t max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +__forceinline__ __device__ +scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = ::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// reflect_coordinates_set_grad works similarly to reflect_coordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high, + scalar_t *grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +template +__forceinline__ __device__ +scalar_t safe_downgrade_to_int_range(scalar_t x){ + // -100.0 does not have special meaning. This is just to make sure + // it's not within_bounds_2d or within_bounds_3d, and does not cause + // undefined behavior. See #35506. 
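+  //
+  // For example, a non-finite or huge coordinate (NAN, INFINITY, 1e30f, ...)
+  // is mapped to -100.0, which later fails within_bounds_2d/within_bounds_3d,
+  // so the corresponding sample contributes nothing instead of indexing out
+  // of range.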
+ if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast(x))) + return static_cast(-100.0); + return x; +} + +template +__forceinline__ __device__ +scalar_t compute_coordinates(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2*(size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2*size - 1); + } + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } + + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +__forceinline__ __device__ +scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + +// grid_sampler_compute_source_index_set_grad works similarly to +// grid_sampler_compute_source_index except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t grid_sampler_compute_source_index_set_grad( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t *grad_in) { + scalar_t grad_clip, grad_refl; + coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl); + } else { + coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl); + } + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +__forceinline__ __device__ +bool within_bounds_2d(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +__forceinline__ __device__ +bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template +__forceinline__ __device__ +scalar_t get_value_bounded( + const scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + +template +__forceinline__ __device__ +void safe_add_2d(scalar_t *data, int h, int w, + int sH, int sW, int H, int W, + scalar_t delta, + const index_t NC_offset, + const index_t memory_span) { + if (within_bounds_2d(h, w, H, W)) { + fastAtomicAdd(data, + NC_offset + h * sH + w * sW, + 
memory_span, + delta, + true); + } +} + +template +__forceinline__ __device__ +void safe_add_3d(scalar_t *data, int d, int h, int w, + int sD, int sH, int sW, int D, int H, int W, + scalar_t delta, + const index_t NC_offset, + const index_t memory_span) { + if (within_bounds_3d(d, h, w, D, H, W)) { + fastAtomicAdd(data, + NC_offset + d * sD + h * sH + w * sW, + memory_span, + delta, + true); + } +} + +template +__forceinline__ __device__ +void add_value_bounded( + scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners, + const index_t NC_offset, + const index_t memory_span) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +__forceinline__ __device__ +void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2ff62706840d18984bfef20adba911ae4293e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_grid_sampler_2d_forward_kernel( + const TensorBase &output, const TensorBase &input, const TensorBase &grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_3d_forward_kernel( + const TensorBase &output, const TensorBase &input, const TensorBase &grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_2d_backward_kernel( + const TensorBase &grad_input, const TensorBase &grad_grid, + const TensorBase &grad_output, const TensorBase &input, + const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +void launch_grid_sampler_3d_backward_kernel( + const TensorBase &grad_input, const TensorBase &grad_grid, + const TensorBase &grad_output, const TensorBase &input, + const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GroupMM.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GroupMM.h new file mode 100644 index 0000000000000000000000000000000000000000..e425619a8bead2e9ff04714870cb7938129758c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GroupMM.h @@ -0,0 +1,12 @@ +#pragma once +#include +#include + +namespace 
at::cuda::detail { +TORCH_API void bf16bf16_grouped_mm( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GroupMMCommon.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GroupMMCommon.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ebfc2cc667a4e127ef570131151535ef7089ec9a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/GroupMMCommon.cuh @@ -0,0 +1,156 @@ +#pragma once +#include + +namespace at::cuda::detail { + +using Strides = std::array; + +template < + typename DtypeA, + typename DtypeB, + typename DtypeOutput, + typename DtypeScale, + typename ProblemShape, + typename StrideA, + typename StrideB, + typename StrideOutput> +__global__ void prepare_grouped_gemm_data( + DtypeA* A, + DtypeB* B, + DtypeOutput* output, + DtypeScale* scale_A, + DtypeScale* scale_B, + DtypeA** A_ptrs, + DtypeB** B_ptrs, + DtypeOutput** output_ptrs, + DtypeScale** inputA_scale_ptrs, + DtypeScale** inputB_scale_ptrs, + ProblemShape* problem_sizes, + // Strides for cutlass, cute::Stride + StrideA* stride_A, + StrideB* stride_B, + StrideOutput* stride_output, + const int32_t* offs, + int32_t M, + int32_t N, + int32_t K, + // Original strides of the input tensors + Strides tensor_StrideA, + Strides tensor_StrideB, + Strides tensor_StrideOutput, + int64_t a_scale_stride, + int64_t b_scale_stride, + bool a_row_major = true, + bool b_row_major = false) { + int32_t tid = threadIdx.x; + int32_t delta = 0; + if (offs != nullptr) { + int32_t start = tid == 0 ? 0 : offs[tid - 1]; + delta = offs[tid] - start; + if (K < 0) { + if (!a_row_major && b_row_major) { + CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n"); + } else { + // CUTLASS cannot handle delta=0 here. + CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n"); + } + } + + // TMA transfers require global memory tensor addresses to be + // aligned to 16 bytes. + if (tid < blockDim.x - 1) { + // Check this requirement for input tensors, in case group + // addresses are increased along the dynamic dimension. + if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension + (M < 0 && !a_row_major)) { // 3D/2D: check along N dimension + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension + (N < 0 && b_row_major)) { // 3D/2D: check along N dimension + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + + // Check the same requirement for output tensor (that is always + // contiguous, and in row-major layout). + if (N < 0) { + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + } + } + int64_t lda, ldb, ldoutput; + if (M < 0) { + // A and output is 2d + M = delta; + lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = tid == 0 ? 
A : A + offs[tid - 1] * tensor_StrideA[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1]; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + } else if (N < 0) { + N = delta; + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[1]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; + } + } else if (K < 0) { + // A, B is 2d, output is 3d + K = delta; + lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * M; + inputB_scale_ptrs[tid] = scale_B + tid * N; + } + } else { + // A, B, output are 3D + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + } + problem_sizes[tid] = ProblemShape(M, N, K); + + // make_cute_packed_stride only replaces one of the stride elements with + // one the provided values in the shape arguments + // the indices of the src/dst depend on whether A/B are row-major + // so constructing shape argument with two similar lda values + // while it looks non-sensical (and it is a nonsensical shape) + // is fine for these stride construction purposes - the one that will be used + // for replacement is correct, the other one is ignored, and we don't have to + // branch on whether A/B are row-major + stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); + stride_output[tid] = + cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1}); +} +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3f3c3f36c11d730a0a3bf2dcda3d6f6edbc8a1e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernel.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at::native { +/// @param maskPrefixSum[in,out] +void launch_masked_scatter_kernel( + const TensorBase &self, const TensorBase &mask, + const TensorBase &maskPrefixSum, const TensorBase &source); +} diff --git 
a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernelUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..eafbdaa1bb754409d9946c9f3a5bc7ad8a747f7c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernelUtils.h @@ -0,0 +1,35 @@ + +#include +#include +#include + +namespace at::native { + +template +inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const out_ptr, char * const in_ptr, const size_t index_stride_bytes, const size_t element_size) { + using at::native::memory::get_alignment; + const auto index_element_size = iter.element_size(2); + //TensorIterator strides and sizes are ordered fastest moving to slowest moving, + //in contrast to regular sizes + // we need contiguous source and dst slices and aligned pointers and strides and slice size to do vectorized loads + // also we need idx to be expanded in the last dimension so we can copy entire slices + // and we need the src tensor to keep 0 stride from restriding + // (it could have been deleted by dimension collapse, in this case iterator would still be 2d + // but we cannot use fast path) + + return iter.ndim() == 2 && iter.strides(2)[0]==0 && iter.strides(2)[1]==index_element_size && + static_cast(iter.strides(0)[0])==element_size && + static_cast(iter.strides(1)[0])==element_size && static_cast(iter.strides(1)[1] == 0) && + get_alignment(out_ptr) == alignment && get_alignment(in_ptr) == alignment && + get_alignment(static_cast(iter.shape()[0] * element_size)) == alignment && + get_alignment(static_cast(index_stride_bytes)) == alignment && + get_alignment(static_cast(iter.strides(0)[1])) == alignment; +} + +template +void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind, + int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, + bool allow_neg_indices=false); + + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9bc143a522dceabe527ac1ab3e60f1dabbca7a08 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh @@ -0,0 +1,186 @@ +#pragma once + +#include + +#if AT_USE_JITERATOR() + +#include + +#include +#include +#include + +#include + +#include + +namespace at::native { + +/* Note [Jiterator] +The "jiterator" simply just-in-time compiles the same kernels that +Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time, +build size, and initial CUDA context size. + +By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels. +This behavior is controlled with two environment variables: + - USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use + - PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels + +The jiterator currently has some limitations, however. It cannot: + - handle math on complex datatypes + - handle kernels with scalar parameters + +These improvements will likely come soon. + +For examples of how to use the jiterator see the i1 and gcd kernel +implementations, which pass jittable strings implementing their +operations instead of the typical CUDA functors. 
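+
+As an illustrative sketch only (the kernel name and source string below are
+hypothetical, not kernels defined in this header), a caller supplies a C++
+source string plus the kernel name, result type, input type and arity:
+
+  constexpr char my_op_name[] = "my_op";
+  const std::string my_op_src =
+      "template <typename T> T my_op(T a, T b) { return a + b; }";
+  jitted_gpu_kernel<my_op_name, float, float, 2>(iter, my_op_src);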
+ +To pass a runtime argument (similar to lambda captures in non-JIT kernels), +we need to pass to additional arguments to `jitted_gpu_kernel` by value. +Currently only primitive C++ types used for computation are valid. +The order of these extra arguments should be same as the order they appear +in kernel's function signature. (look at polygamma for example) + +NOTE: One big restriction being that these arguments should be after the +arguments provided by TensorIterator. Eg. While capturing `n`, where +`scalar_t x` and `scalar_t y` are provided by TensorIterator, +* foo(scalar_t x, scalar_t y, int n) works! +* foo(int n, scalar_t x, scalar_y) doesn't work +* foo(scalar_t x, int n, scalar_y) doesn't work + +*/ + +// Entrypoint for jitted GPU kernels. +// Only handles elementwise unary and binary kernels with a +// common dtype and a single output. +// NOTE: this assumes the op's iterator has a common_dtype. +// NOTE: We use std::tuple instead of parameter pack +// for `extra_args` due to following +// bug on older versions of clang +// https://bugs.llvm.org/show_bug.cgi?id=23029 +template < + char const* name, + typename return_type, + typename f_inputs_type, + int arity, + typename... Args> +void jitted_gpu_kernel( + TensorIteratorBase& iter, + const std::string& f, + at::cuda::jit::BinaryFuncVariant scalar_pos = + at::cuda::jit::BinaryFuncVariant::NoScalar, + at::opmath_type scalar_val = 0, + std::tuple extra_args = std::make_tuple()) { + // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel + // Maybe it could be refactored? + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + jitted_gpu_kernel( + sub_iter, f, scalar_pos, scalar_val, extra_args); + } + + return; + } + + // Computes if dynamic casting is needed + // Dynamic casting is needed if an input's dtype differs from the common dtype + // or if the result dtype differs from the output's dtype + // Note: this is intentionally divergent from calling needs_dynamic_casting, + // which is more general and inspects a lambda to determine if dynamic + // casting is needed. + bool needs_dynamic_casting = false; + + // Checks output + const ScalarType return_scalar_type = c10::CppTypeToScalarType::value; + const auto dtype0 = iter.dtype(0); + if (dtype0 != return_scalar_type) { + needs_dynamic_casting = true; + } + + // Checks input(s) + const ScalarType inputs_scalar_type = c10::CppTypeToScalarType::value; + for (auto i = decltype(arity){1}; i < (arity + 1); ++i) { + const auto dtypei = iter.dtype(i); + if (dtypei != inputs_scalar_type) { + needs_dynamic_casting = true; + break; + } + } + if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) { + // NOTE: With `scalar_pos=NoScalar`,`scalar_val` is not used + // for computation in the generated code and hence we pass a dummy + // value of `0`. 
+ jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::cuda::jit::BinaryFuncVariant::NoScalar>( + iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args); + } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) { + jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::cuda::jit::BinaryFuncVariant::RhsScalar>( + iter, + f, + needs_dynamic_casting, + scalar_val, + extra_args); + + } else { + jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::cuda::jit::BinaryFuncVariant::LhsScalar>( + iter, + f, + needs_dynamic_casting, + scalar_val, + extra_args); + } +} + +// TODO: support runtime state capture similar to `jitted_gpu_kernel`. +template +void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + //currently jiterator only handles binary functions where both inputs are of the same type (f_inputs_type) + using opmath_t = at::opmath_type; + if (iter.is_cpu_scalar(1)) { + auto scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + jitted_gpu_kernel(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val); + } else if (iter.is_cpu_scalar(2)) { + auto scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + jitted_gpu_kernel(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val); + } else { + jitted_gpu_kernel(iter, f); + } +} + +} // namespace at::native + +#endif // AT_USE_JITERATOR() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6eb18e74474335cd5363633dce7077064cb3c8a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh @@ -0,0 +1,367 @@ +#pragma once +#include + +#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + +// ROCm 6.3 is planned to have these functions, but until then here they are. 
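+// The two overloads below atomically add a packed pair of bf16/fp16 values:
+// on gfx942 they lower to the packed-fadd builtins, on other targets they
+// fall back to a compare-and-swap loop on the containing 32-bit word.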
+#if defined(USE_ROCM) && ROCM_VERSION >= 60201 +#include +#include +#include + +__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) + typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; + static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); + union { + __hip_bfloat162_raw bf162_raw; + vec_short2 vs2; + } u{static_cast<__hip_bfloat162_raw>(value)}; + u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2); + return static_cast<__hip_bfloat162>(u.bf162_raw); +#else + static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw)); + union u_hold { + __hip_bfloat162_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} + +__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16) + // The api expects an ext_vector_type of half + typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; + static_assert(sizeof(vec_fp162) == sizeof(__half2_raw)); + union { + __half2_raw h2r; + vec_fp162 fp16; + } u {static_cast<__half2_raw>(value)}; + u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16); + return static_cast<__half2>(u.h2r); +#else + static_assert(sizeof(__half2_raw) == sizeof(unsigned int)); + union u_hold { + __half2_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} +#define ATOMICADD preview_unsafeAtomicAdd +#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f) +#else +#define ATOMICADD atomicAdd +#define NATIVE_ZERO_BF16 __int2bfloat16_rz(0) +#endif + +namespace at:: native { + +__device__ __forceinline__ size_t +idx(const size_t nc, + const size_t height, + const size_t width, + const size_t h, + const size_t w) { + return (nc * height + h) * width + w; +} + +// for channels-last +__device__ __forceinline__ size_t +idx_cl( + const size_t n, const size_t h, const size_t w, const size_t c, + const size_t height, const size_t width, const size_t channel +) { + return ((n * height + h) * width + w) * channel + c; +} + +// fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization +// that speed up half-precision atomics. The situation with half +// precision atomics is that we have a slow __half atomic, and +// a fast vectored __half2 atomic (this can be worth up to a 6x +// speedup, see https://github.com/pytorch/pytorch/pull/21879). +// We can convert a __half atomic into a __half2 atomic by simply +// pairing the __half with a zero entry on the left/right depending +// on alignment... but only if this wouldn't cause an out of bounds +// access! 
Thus, you must specify tensor and numel so we can check +// if you would be out-of-bounds and use a plain __half atomic if +// you would be. +template < + typename scalar_t, + typename index_t, + typename std::enable_if_t>* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM) && ROCM_VERSION < 60201) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __half* target_addr = reinterpret_cast<__half*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__half2) == 0); + + if (low_byte && index < (numel - 1)) { + __half2 value2; + value2.x = static_cast<__half>(value); + value2.y = __int2half_rz(0); + ATOMICADD(reinterpret_cast<__half2*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __half2 value2; + value2.x = __int2half_rz(0); + value2.y = static_cast<__half>(value); + ATOMICADD(reinterpret_cast<__half2*>(target_addr - 1), value2); + + } else { +#ifdef USE_ROCM + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, static_cast(value)); +#else + atomicAdd( + reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value)); +#endif + } +#endif +} + +template < + typename scalar_t, + typename index_t, + typename std::enable_if_t>* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM) && ROCM_VERSION < 60201) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__nv_bfloat162) == 0); + + if (low_byte && index < (numel - 1)) { + __nv_bfloat162 value2; + value2.x = *reinterpret_cast<__nv_bfloat16*>(&value); + value2.y = NATIVE_ZERO_BF16; + ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __nv_bfloat162 value2; + value2.x = NATIVE_ZERO_BF16; + value2.y = *reinterpret_cast<__nv_bfloat16*>(&value); + ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2); + + } else { +#ifdef USE_ROCM + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, static_cast(value)); +#else + atomicAdd( + reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value)); +#endif + } +#endif +} + + +template < + typename scalar_t, + typename index_t, + typename std::enable_if_t && !std::is_same_v>* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { + gpuAtomicAddNoReturn(tensor + index, value); +} + +template +__device__ __forceinline__ void fastAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value, + bool fast_atomics) { + if (fast_atomics) { + fastSpecializedAtomicAdd(tensor, index, numel, value); + } else { + gpuAtomicAddNoReturn(tensor + index, value); + } +} + +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)) +// This 
function implements warp-level opportunistic fastatomics +// To reduce contention on an atomicAdd, this replaces per-thread atomicAdd with a per-warp atomicAdd. +// We identify all the threads within a warp that will perform an atomicAdd on the same destination +// address and perform the addition on the CU. Each warp elects a leader thread which does the +// atomicAdd to the destination address. +template +__device__ __forceinline__ void opportunistic_fastAtomicAdd( + scalar_t* self_ptr, + index_t index, + const index_t numel, + scalar_t value) { + + scalar_t* dst = self_ptr + index; + + //pack coalseced bf16 and fp16 + if constexpr (std::is_same::value || std::is_same::value) + { + typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; + union ill { unsigned int i[2]; int64_t il; }; + ill iil_, ill_oneUpDst = {}; + iil_.il = (int64_t)dst; + ill_oneUpDst.i[0] = __builtin_amdgcn_mov_dpp(iil_.i[0], 0x130, 0xf, 0xf, 0); + ill_oneUpDst.i[1] = __builtin_amdgcn_mov_dpp(iil_.i[1], 0x130, 0xf, 0xf, 0); + union bfi {scalar_t bf; short s; } bfi_ = { .bf = value }; bfi bfi_oneUpVal; + + bfi_oneUpVal.s = __builtin_amdgcn_mov_dpp(bfi_.s, 0x130, 0xf, 0xf, 0); + auto oneUpVal = bfi_oneUpVal.bf; + + __half* target_addr = reinterpret_cast<__half*>(self_ptr + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__half2) == 0); + bool canCombnUp = (bool)(__activemask()&(1<<(threadIdx.x+1))) && + (low_byte && index < (numel - 1)) && + (ill_oneUpDst.il - reinterpret_cast(dst) == sizeof(scalar_t)); + bool canCombnDn = (__builtin_amdgcn_mov_dpp(canCombnUp, 0x138, 0xf, 0xf, 0)); + + if (__lane_id()%2==0) + { + if (canCombnUp) { + typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; + union bfvs { scalar_t bf[2]; vec_short2 vs2; vec_fp162 df16; }; + bfvs bfvs_ = {}; + bfvs_.bf[0] = value; + bfvs_.bf[1] = oneUpVal; + if constexpr (std::is_same::value) + __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)dst, bfvs_.vs2); + else + __builtin_amdgcn_flat_atomic_fadd_v2f16((__half2*)dst, bfvs_.df16); + return; + } + } + else + { + if (canCombnDn) + return; + } + } + + // not coalsced, so now let try to capture lane-matches... 
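+  // Worked example (illustrative): if lanes 2, 5 and 9 of a warp all target
+  // the same destination address, the mask computed below has bits 2, 5 and 9
+  // set; lane 2 (the lowest set bit) is elected leader, accumulates the three
+  // values via __shfl, and issues a single atomic add on behalf of the group.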
+ // __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd + // __match_any_sync() -- returns bit mask of the threads that have same dest addr + auto mask = __match_any_sync(__activemask(), (int64_t)dst); + + // select a leader thread + int leader = __ffsll(mask) - 1; + + scalar_t crnt_val = (scalar_t)0; + auto crnt_msk = mask >> (leader); + int crnt_idx = leader; + + // __shfl is limited in the dtypes it accepts + // That's why, we need these if/else to correctly do the addition on the CU + if constexpr(sizeof(scalar_t) <= sizeof(int)) { + union punner { int l; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l = __shfl(pnr.l ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + else if constexpr(sizeof(scalar_t) <= sizeof(long)) { + union punner { long l; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l = __shfl(pnr.l ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + else if constexpr(sizeof(scalar_t) <= sizeof(long long)) { + union punner { long long l; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l = __shfl(pnr.l ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + else { + union punner { long long l[2]; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l[0] = __shfl(pnr.l[0] ,crnt_idx); + add_val.l[1] = __shfl(pnr.l[1] ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + + + //Once the correct crnt_val is determined, only the leader thread does the update to the dest addr + if (__lane_id() == leader) { + fastAtomicAdd(self_ptr, index, numel, crnt_val, true); + } +} +#endif + +#undef ATOMICADD +#undef NATIVE_ZERO_BF16 + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..027eb5877c59ca674d151cce67602b0e3a52bb43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h @@ -0,0 +1,16 @@ +#pragma once +#include + +namespace at::native { + +// returns 2**floor(log2(n)) +static int lastPow2(unsigned int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Loops.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Loops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e6a03252fa5eb6c505be3a44650e154c59668b85 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Loops.cuh @@ -0,0 +1,330 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +namespace at::native { + +template +static OffsetCalculator make_input_offset_calculator(const TensorIteratorBase& iter) { + // array size can not be 0, this happens when N == 0 + constexpr int array_size = std::max(N, 1); + TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs()); + std::array strides; + 
int64_t element_sizes[array_size]; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i + iter.noutputs()).data(); + element_sizes[i] = iter.element_size(i + iter.noutputs()); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +static OffsetCalculator make_output_offset_calculator(const TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs()); + std::array strides; + int64_t element_sizes[num_outputs]; + for (int i = 0; i < num_outputs; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { + using traits = function_traits; + using return_t = typename traits::result_type; + using args_t = typename traits::ArgsTuple; + constexpr int elems_per_thread = policy_t::tws; + + int idx = blockIdx.x; + if constexpr (reverted_idx) + idx = gridDim.x - blockIdx.x - 1; + + return_t results[elems_per_thread]; + args_t args[elems_per_thread]; + + // load + policy.load(args, idx); + + // compute + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) { + if (policy.check_inbounds(i)) { + results[i] = c10::guts::apply(f, args[i]); + } + } + + // store + policy.store(results, idx); +} + +} // namespace at::native + +#include + +namespace at:: native { + +template +void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_nocast(sub_iter, f); + } + return; + } + + gpu_kernel_impl_nocast(iter, f); +} + +template +void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel(sub_iter, f); + } + return; + } + + gpu_kernel_impl(iter, f); +} + +template +struct AUnaryFunctor { + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + __device__ return_t operator()(arg2_t b) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {} + private: + func_t f; + opmath_arg1_t a; +}; + +template +struct BUnaryFunctor { + using traits = function_traits; + using opmath_arg2_t = typename traits::template arg<1>::type; + __device__ return_t operator()(arg1_t a) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! 
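+  // For example, adding a CPU scalar to a Half tensor: `b` is captured once
+  // at opmath (float) precision, while `a` is still loaded from memory as
+  // Half, matching the note on opmath_gpu_kernel_with_scalars below.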
+ BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {} + private: + func_t f; + opmath_arg2_t b; +}; + +// Though seemingly noop, this inserts casts from arg1_t to func_t's type +// (which may be higher precision), as well as casts to return_t +template +struct BinaryFunctor { + __device__ return_t operator()(arg1_t a, arg2_t b) const { + return f(a, b); + } + BinaryFunctor(func_t f_): f(f_) {} + private: + func_t f; +}; + +// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which +// accepts inputs at higher precision (typically opmath_t), but then +// ensure that we load from memory at the correct precision (scalar_t) +// to avoid expensive loads. For the whole sordid story see +// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302 +template +void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + using opmath_arg2_t = typename traits::template arg<1>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + + if (iter.is_cpu_scalar(1)) { + AUnaryFunctor af(f, iter.scalar_value(1)); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + gpu_kernel(iter, af); + } else if (iter.is_cpu_scalar(2)) { + BUnaryFunctor bf(f, iter.scalar_value(2)); + iter.remove_operand(2); + gpu_kernel(iter, bf); + } else { + gpu_kernel(iter, BinaryFunctor(f)); + } +} + +template +void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + // Use symmetric property of the functor to reduce number of kernels, + // requires f(a, b) == f(b, a) + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg_t = typename traits::template arg<0>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + static_assert(std::is_same_v::type>, + "f is not symmetric"); + + OptionalDeviceGuard device_guard; + opmath_arg_t scalar_val{}; + + if (iter.is_cpu_scalar(1)) { + scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. 
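The AUnaryFunctor/BUnaryFunctor wrappers above bind a CPU scalar into one argument of the binary op at the functor's higher (opmath) precision, so the remaining tensor operand is still read from memory at its storage precision. A small host-side sketch of that binding idea, with hypothetical names and float/int16_t standing in for opmath_t/scalar_t:

#include <cstdint>
// Hypothetical illustration of binding the first operand of a binary op into
// a unary functor; the bound scalar is kept at compute precision (float) while
// the streamed operand stays at storage precision (int16_t as a stand-in).
template <typename func_t>
struct BindFirstArg {
  func_t f;
  float a;  // captured scalar at "opmath" precision
  BindFirstArg(func_t f_, float a_) : f(f_), a(a_) {}
  float operator()(int16_t b) const {
    return f(a, static_cast<float>(b));  // cast up only at compute time
  }
};
// Usage:
//   auto add = [](float x, float y) { return x + y; };
//   BindFirstArg<decltype(add)> bound(add, 0.5f);
//   float r = bound(int16_t{3});  // r == 3.5f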
This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + device_guard.reset_device(iter.device(1)); + } else if (iter.is_cpu_scalar(2)) { + scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + } + + if (iter.ninputs() == 2) { + gpu_kernel(iter, BinaryFunctor(f)); + } else { + AUnaryFunctor unary_f(f, scalar_val); + gpu_kernel(iter, unary_f); + } +} + +// Legacy variant that assumes that func_t has the correct types +// that we expect to load from memory +template +void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; + using return_t = typename traits::result_type; + opmath_gpu_kernel_with_scalars(iter, f); +} + +namespace { // functions for `gpu_kernel_multiple_outputs`. + +// check the return type is `thrust::tuple`, not `std::tuple`. +template struct is_tuple: std::false_type {}; + +template struct is_tuple>: std::true_type {}; + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) { + int remaining = N - block_work_size() * blockIdx.x; + elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll(data, remaining, ic, oc)); +} + +template +static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel_for_multi_outputs<<>>(N, f, data, ic, oc); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using output_t = typename traits::result_type; + static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); + constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_inputs = traits::arity; + constexpr int ntensors = num_outputs + num_inputs; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + if (iter.is_contiguous()) { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator(); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } else { + auto input_calc = make_input_offset_calculator(iter); + auto output_calc = make_output_offset_calculator(iter); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } +} +} // namespace + +template +void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) { + ASSERT_HOST_DEVICE_LAMBDA(func_t); + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda()); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_multiple_outputs(sub_iter, 
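gpu_kernel_multiple_outputs above requires the functor to return a thrust::tuple (checked by the is_tuple trait), and the number of outputs is deduced from the tuple size so each component can be scattered to its own output tensor. A host-side analogue of that one-pass, multiple-output pattern (hypothetical names; the device path needs thrust::tuple rather than std::tuple):

#include <cmath>
#include <tuple>
#include <vector>
// One functor invocation produces all outputs for an element; the components
// are then written to separate output buffers, mirroring the pattern above.
template <typename func_t>
void apply_two_outputs(const std::vector<float>& in,
                       std::vector<float>& out0,
                       std::vector<float>& out1,
                       func_t f) {
  for (size_t i = 0; i < in.size(); ++i) {
    auto [a, b] = f(in[i]);  // f returns a tuple-like pair of results
    out0[i] = a;
    out1[i] = b;
  }
}
// Usage:
//   apply_two_outputs(x, s, c,
//       [](float v) { return std::make_tuple(std::sin(v), std::cos(v)); });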
f); + } + return; + } + + gpu_kernel_multiple_outputs_impl(iter, f); +} + +} //namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Math.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Math.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ba26226dd646ec8ba4c9865c44b6dc42e6944a30 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Math.cuh @@ -0,0 +1,3390 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { +// See note [Jiterator] +// TODO: elaborate in this comment on the structure of math.cuh +#if AT_USE_JITERATOR() + +const auto ndtri_string = jiterator_stringify( + /* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates polynomial of degree N: + * + * 2 N + * y = C + C x + C x +...+ C x + * 0 1 2 N + * + * Coefficients are stored in reverse order: + * + * coef[0] = C , ..., coef[N] = C . + * N 0 + */ + template + T polevl(const T x, const T A[], const int len) { + // NOTE: This `polevl` is different from other `polevl` + // implementation (in PyTorch) which expect the `len` to be + // `len(A) - 1` instead of `len(A)`. + T result = 0; + for (int i = 0; i < len; ++i) { + result = result * x + A[i]; + } + return result; + } + + /* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes the argument, x, for which the area under the Gaussian probability density function + * (integrated from minus infinity to x) is equal to y. + */ + template + T ndtri(T y0) { + + constexpr T zero = 0; + constexpr T one = 1; + + // Handles special cases + if (y0 == zero) { + return NEG_INFINITY; + } + if (y0 == one) { + return POS_INFINITY; + } + if (y0 < zero || y0 > one) { + return NAN; + } + + bool code = true; + T y = y0; + // Note: the constant 0.135... is equal to exp(-2) + if (y > one - T{0.13533528323661269189}) { + y = one - y; + code = false; + } + + if (y > T{0.13533528323661269189}) { + /* approximation for 0 <= |y - 0.5| <= 3/8 */ + static const T P0[5] = { + -5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0, + }; + + static const T Q0[9] = { + 1.00000000000000000000E0, + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0, + }; + + /* sqrt(2pi) */ + constexpr T s2pi = 2.50662827463100050242E0; + + y = y - T{0.5}; + const T y2 = y * y; + T x = y + y * (y2 * polevl(y2, P0, int{5}) / polevl(y2, Q0, int{9})); + return x * s2pi; + } + + T x = sqrt(T{-2.} * log(y)); + const T x0 = x - (log(x) / x); + + const T z = one / x; + T x1; + + /* y > exp(-32) = 1.2664165549e-14 */ + if (x < T{8.0}) { + /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8 + * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14. 
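The polevl above is plain Horner evaluation with coefficients stored highest degree first, and, as its note points out, len is the number of coefficients rather than the polynomial degree. A tiny host-side check of that convention (hypothetical helper name):

// Horner evaluation matching the jiterator polevl convention above:
// coef[0] is the highest-degree coefficient and len == number of coefficients.
static double polevl_host(double x, const double coef[], int len) {
  double result = 0.0;
  for (int i = 0; i < len; ++i) {
    result = result * x + coef[i];
  }
  return result;
}
// p(x) = 2x^2 + 3x + 4, stored highest degree first, so len == 3:
//   const double coef[] = {2.0, 3.0, 4.0};
//   polevl_host(2.0, coef, 3) == 2*4 + 3*2 + 4 == 18.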
+ */ + static const T P1[9] = { + 4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4, + }; + + static const T Q1[9] = { + 1.00000000000000000000E0, + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4, + }; + + x1 = z * polevl(z, P1, int{9}) / polevl(z, Q1, int{9}); + } else { + /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64 + * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890. + */ + static const T P2[9] = { + 3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9, + }; + + static const T Q2[9] = { + 1.00000000000000000000E0, + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9, + }; + + x1 = z * polevl(z, P2, int{9}) / polevl(z, Q2, int{9}); + } + + x = x0 - x1; + return (!code) ? x : -x; + } +); // ndtri_string + +const auto log_ndtr_string = jiterator_stringify( + template + T log_ndtr(T x) { + constexpr T SQRT1_2{0.707106781186547524400844362104849039}; // 1/sqrt(2) + T t = x * SQRT1_2; + if (x < T{-1.0}) { + return log(erfcx(-t) / 2) - t * t; + } else { + return log1p(-erfc(t) / 2); + } + } +); // log_ndtr_string + +const auto gcd_string = jiterator_stringify( + template + T gcd(const T a_in, const T b_in) { + T a = abs(a_in); + T b = abs(b_in); + + while (a != T{0}) { + T c = a; + a = b % a; + b = c; + } + + return b; + } +); // gcd_string + +const auto lcm_string = jiterator_stringify( + template + T gcd(const T a_in, const T b_in) { + T a = abs(a_in); + T b = abs(b_in); + + while (a != T{0}) { + T c = a; + a = b % a; + b = c; + } + + return b; + } + + template + T lcm(const T a, const T b) { + T g = gcd(a, b); + return (g == T{0}) ? T{0} : abs(a / g * b); + } +); // lcm_string + +/* + * For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h". + */ +// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma +const auto digamma_string = jiterator_stringify( + template + T digamma(T x) { + static const double PI_f64 = 3.14159265358979323846; + + // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard + if (x == 0) { + return copysign(POS_INFINITY, -x); + } + + T result = 0; + if (x < 0) { + // Short-circuits if x is a negative integer and returns NaN + // per the C++ standard + const bool x_is_integer = (x == trunc(x)); + if (x_is_integer) { + return NAN; + } + + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). 
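The lcm helper above computes abs(a / g * b), dividing by the gcd before multiplying so the intermediate value stays in range whenever the final lcm itself fits in the type. A host-side sketch of the same divide-first ordering (hypothetical names and test values):

#include <cstdint>
#include <cstdlib>
// Euclidean gcd plus the divide-first lcm used above.
static int64_t gcd_host(int64_t a, int64_t b) {
  a = std::llabs(a); b = std::llabs(b);
  while (a != 0) { int64_t c = a; a = b % a; b = c; }
  return b;
}
static int64_t lcm_host(int64_t a, int64_t b) {
  int64_t g = gcd_host(a, b);
  return (g == 0) ? 0 : std::llabs(a / g * b);
}
// lcm_host(6'000'000'000, 4'000'000'000) == 12'000'000'000: a * b would
// overflow int64_t, while a / g * b (== 3 * 4'000'000'000) does not.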
+ double q, r; + r = modf(static_cast(x), &q); + result = - PI_f64 / tan(PI_f64 * r); + x = 1 - x; + } + + while (x < T{10}) { + result -= T{1} / x; + x += T{1}; + } + + if (x == T{10}) { + return result + T{2.25175258906672110764}; + } + + T y = 0; + if (x < T{1.0e17}) { + const T A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + + T z = T{1} / (x * x); + + T polevl_result = 0; + for (int i = 0; i <= 6; i++) { + polevl_result = polevl_result * z + A[i]; + } + y = z * polevl_result; + } + + return log(x) - (T{0.5} / x) - y + result; + } +); // digamma_string + +/* + * This function is derived from the implementation of the zeta function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +const auto zeta_string = jiterator_stringify( + template + T zeta(T x, T q) { + const T MACHEP{1.11022302462515654042E-16}; + constexpr T zero{0}; + constexpr T half{0.5}; + constexpr T one{1}; + static const T A[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, /*1.307674368e12/691*/ + 7.47242496e10, + -2.950130727918164224e12, /*1.067062284288e16/3617*/ + 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ + -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ + 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ + -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ + }; + + int i = 0; + T a, b, k, s, t, w; + + // Short-circuits x -> +infty + if (x == one) { + return POS_INFINITY; + } + + // Short-circuits x < 1 -> NaN + if (x < one) { + return NAN; + } + + // Short-circuits negative q integers map to +infty, + // negative q non-integers map to NaN + if (q <= zero) { + if (q == floor(q)) { + return POS_INFINITY; + } + if (x != floor(x)) { + return NAN; + } + } + + s = pow(q, -x); + a = q; + i = 0; + b = zero; + while ((i < 9) || (a <= T{9.0})) { + i += 1; + a += one; + b = pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return s; + } + }; + + w = a; + s += b * w / (x - one); + s -= half * b; + a = one; + k = zero; + for (int i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / A[i]; + s = s + t; + t = fabs(t / s); + + if (t < MACHEP) { + return s; + } + + k += one; + a *= x + k; + b /= w; + k += one; + } + + return s; + } +); // zeta_string + +const auto trigamma_string = jiterator_stringify( + template + T trigamma(T x) { + const T PI{3.14159265358979323846}; + T sign = 1; + T result = 0; + + if (x < T{0.5}) { + sign = -1; + T sin_pi_x = sin(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + + for (int i = 0; i < 6; ++i) { + result += T{1} / (x * x); + x += 1; + } + + const T one{1}; + const T ixx = one / (x*x); + result += (one + one / (T{2}*x) + ixx * (one/T{6} - ixx * (one/T{30} - ixx * (one/T{42})))) / x; + return sign * result; +} +); // trigamma_string + +const auto lgamma_string = jiterator_stringify( + template + T lgamma_kernel(T a) { + return lgamma(a); + } +); // lgamma_string + +const auto polygamma_string = zeta_string + jiterator_stringify( + template + T polygamma(T x, int n) { + // already blocked if n <= 1 + const auto one = T{1}; + return ((n % 2) ? 
one : -one) * exp(lgamma(static_cast(n) + one)) * + zeta(static_cast(n + 1), x); + } +); // polygamma_string + +const auto exp2_string = jiterator_stringify( + template + T exp2_impl(T a) { + return exp2(a); + } + + namespace std { template class complex; } + template + std::complex exp2_impl(std::complex x) { + // There is no std::exp2 overload for complex, so instead + // use the identity 2^x = e^(ln(2) * x) + const auto ln_2 = static_cast(0.693147180559945309417232121458176); + return exp(ln_2 * x); + } + + template + T exp2_kernel(T a) { + return exp2_impl(a); + } +); // exp2_string + +const auto erfc_string = jiterator_stringify( + template + T erfc_kernel(T a) { + return erfc(a); + } +); // erfc_string + +const auto erfinv_string = jiterator_stringify( + template + T erfinv_kernel(T a) { + return erfinv(a); + } +); // erfinv_string + +const auto entr_string = jiterator_stringify( + template + T entr(T a) { + if (a != a) { + return a; + } + + if (a > 0) { + return -a * log(a); + } + + if (a == 0) { + return 0; + } + + return NEG_INFINITY; + } +); // entr_string + +// NOTE: `kaiser_window_string` depends on `i0_string` +// for its implementation. +const auto i0_string = jiterator_stringify( + template + T chbevl(T x, const T array[], const int len) { + + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + T i0(T _x) { + T x = fabs(_x); + + if (x <= T{8.0}) { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + static const T A[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + T y = (x / T{2.0}) - T{2.0}; + return exp(x) * chbevl(y, A, int{30}); + } + + // Handles x > 8 case + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 
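exp2_impl above handles complex inputs through the identity 2^z = e^(z·ln 2), since there is no exp2 overload for complex arguments. The same identity as a standalone host-side check (hypothetical helper name):

#include <complex>
// 2^z = e^(z * ln 2): the identity used by the complex exp2 branch above.
static std::complex<double> exp2_complex(std::complex<double> z) {
  const double ln2 = 0.693147180559945309417232121458176;
  return std::exp(ln2 * z);
}
// exp2_complex({3.0, 0.0}) == {8.0, 0.0} up to rounding, and
// exp2_complex({0.0, 1.0}) == cos(ln 2) + i*sin(ln 2) by Euler's formula.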
+ */ + const T B[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return (exp(x) * chbevl(T{32.0} / x - T{2.0}, B, int{25})) / sqrt(x); + } +); // i0_string + +const auto i1_string = jiterator_stringify( + template + T chbevl(const T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + T i1(T _x) { + const T x = fabs(_x); + + if (x <= T{8.0}) { + // Chebyshev coefficients for exp(-x) i1(x) in the internal [0, 8] + // lim(x->0){ exp(-x) i1(x) / x } = 1/2 + static const T coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + const T y = x / T{2.0} - T{2.0}; + const T out = exp(x) * x * chbevl(y, coefficients, int{29}); + return (_x < T{0.0}) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval [8, infinity] + // lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi) + static const T coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + const T out = (exp(x) * chbevl(T{32.} / x - T{2.}, coefficients, int{25})) / sqrt(x); + return (_x < T{0.}) ? 
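The Bessel approximations above evaluate a Chebyshev series via chbevl (a Clenshaw-style recurrence) after mapping the argument into [-2, 2]: x/2 - 2 on [0, 8] and 32/x - 2 on the inverted interval [8, ∞). A host-side restatement of the recurrence and those mappings (hypothetical name; the coefficient arrays are the tables shown above):

// Clenshaw-style evaluation matching chbevl above; x must already be
// mapped into [-2, 2] and array[] holds the Chebyshev coefficients.
static double chbevl_host(double x, const double array[], int len) {
  double b0 = array[0], b1 = 0.0, b2 = 0.0;
  for (int i = 1; i < len; ++i) {
    b2 = b1;
    b1 = b0;
    b0 = x * b1 - b2 + array[i];
  }
  return 0.5 * (b0 - b2);
}
// Argument mappings used by i0/i1 above:
//   x in [0, 8]   ->  x / 2 - 2   (covers [-2, 2])
//   x in [8, inf) ->  32 / x - 2  (covers (-2, 2])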
-out : out; + } +); // i1_string + +const auto i1e_string = jiterator_stringify( + template + T chbevl(const T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + // See double and float instantiations below + template + T i1e(T _x) { } + + // Double specialization (uses different coefficients than the float version) + template<> + double i1e(double _x) { + const double x = fabs(_x); + if (x <= double{8.}) { + // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8]. + // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2. + static const double coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + const double y = x / double{2.} - double{2.}; + const double out = chbevl(y, coefficients, int{29}) * x; + return (_x < 0.) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval (8, infinity]. + // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi). + // TODO: what's an "inverted interval"? Open on the left + // and closed on the right? + static const double coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + const double out = chbevl(double{32.} / x - double{2.}, coefficients, int{25}) / sqrt(x); + return (_x < double{0.}) ? -out : out; + } + + // Float specialization (uses different coefficients than the double version) + template<> + float i1e(float _x) { + const float x = fabsf(_x); + if (x <= float{8.}) { + // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8]. + // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2. 
+ static const float coefficients[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + const float y = x / float{2.} - float{2.}; + const float out = chbevl(y, coefficients, int{17}) * x; + return (_x < 0.) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval (8, infinity]. + // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi). + // TODO: what's an "inverted interval"? Open on the left + // and closed on the right? + static const float coefficients[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + const float out = chbevl(float{32.} / x - float{2.}, coefficients, int{7}) / sqrt(x); + return (_x < float{0.}) ? -out : out; + } +); // i1e_string + +const auto kaiser_window_string = i0_string + jiterator_stringify( + template + T kaiser_window(T a, T inv_alpha, T beta, T inv_i0_beta) { + T x = a * inv_alpha - T{1}; + T y = max(T{0}, T{1} - x * x); + return i0(beta * sqrt(y)) * inv_i0_beta; + } +); // kaiser_window_string + +const auto sinc_string = jiterator_stringify( + template + T sinc(T a) { + if (a == T(0)) { + return T(1); + } + constexpr T pi = T(3.14159265358979323846L); + T product = pi * a; + return std::sin(product) / product; + } +); // sinc_string + +const auto erfcx_string = jiterator_stringify( + /* The next function is taken from http://ab-initio.mit.edu/faddeeva */ + + /* Copyright (c) 2012 Massachusetts Institute of Technology + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + /* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by + Steven G. Johnson, October 2012. + + This function combines a few different ideas. + + First, for x > 50, it uses a continued-fraction expansion (same as + for the Faddeeva function, but with algebraic simplifications for z=i*x). 
+ + Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations, + but with two twists: + + a) It maps x to y = 4 / (4+x) in [0,1]. This simple transformation, + inspired by a similar transformation in the octave-forge/specfun + erfcx by Soren Hauberg, results in much faster Chebyshev convergence + than other simple transformations I have examined. + + b) Instead of using a single Chebyshev polynomial for the entire + [0,1] y interval, we break the interval up into 100 equal + subintervals, with a switch/lookup table, and use much lower + degree Chebyshev polynomials in each subinterval. This greatly + improves performance in my tests. + + For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), + with the usual checks for overflow etcetera. + + Performance-wise, it seems to be substantially faster than either + the SLATEC DERFC function [or an erfcx function derived therefrom] + or Cody's CALERF function (from netlib.org/specfun), while + retaining near machine precision in accuracy. + */ + + /* Given y100 = 100 * y, where y = 4 / (4 + x) for x >= 0, compute erfc(x). + + Uses a look-up table of 100 different Chebyshev polynomials + for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated + with the help of Maple and a little shell script. This allows + the Chebyshev polynomials to be of significantly lower degree (about 1/4) + compared to fitting the whole [0,1] interval with a single polynomial. + */ + + // TODO: review if this is computing in double when given a float input + template + T erfcx_y100(T y100) { + switch (static_cast(y100)) { + case 0: { + T t = 2*y100 - 1; + return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t; + } + case 1: { + T t = 2*y100 - 3; + return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 2: { + T t = 2*y100 - 5; + return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 3: { + T t = 2*y100 - 7; + return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t; + } + case 4: { + T t = 2*y100 - 9; + return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t; + } + case 5: { + T t = 2*y100 - 11; + return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t; + } + case 6: { + T t = 2*y100 - 13; + return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 7: { + T 
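The Faddeeva-package comment above lays out three regimes: a continued-fraction expansion for large positive x, a 100-way table of low-degree Chebyshev fits in y = 4/(4 + x) for moderate x (the long switch in erfcx_y100), and the reflection erfcx(-x) = 2·exp(x²) - erfcx(x) for negative arguments. A schematic of just that dispatch, with the table passed in as a callable (hypothetical names; the overflow guards for very negative x are omitted):

#include <cmath>
// Skeleton of the erfcx regime selection described above; `table` stands in
// for the 100-case Chebyshev lookup erfcx_y100.
template <typename Table>
static double erfcx_sketch(double x, Table table) {
  const double inv_sqrt_pi = 0.5641895835477563;   // 1/sqrt(pi)
  if (x >= 0) {
    if (x > 50.0) {
      return inv_sqrt_pi / x;                      // leading term of the continued fraction
    }
    return table(400.0 / (4.0 + x));               // y100 = 100 * 4 / (4 + x)
  }
  // Reflection for negative arguments: erfcx(-x) = 2 exp(x^2) - erfcx(x).
  return 2.0 * std::exp(x * x) - erfcx_sketch(-x, table);
}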
t = 2*y100 - 15; + return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t; + } + case 8: { + T t = 2*y100 - 17; + return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 9: { + T t = 2*y100 - 19; + return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 10: { + T t = 2*y100 - 21; + return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t; + } + case 11: { + T t = 2*y100 - 23; + return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 12: { + T t = 2*y100 - 25; + return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 13: { + T t = 2*y100 - 27; + return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 14: { + T t = 2*y100 - 29; + return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 15: { + T t = 2*y100 - 31; + return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 16: { + T t = 2*y100 - 33; + return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 17: { + T t = 2*y100 - 35; + return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t; + } + case 18: { + T t = 2*y100 - 37; + return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t; + } + case 19: { + T t = 2*y100 - 39; + return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 
0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 20: { + T t = 2*y100 - 41; + return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 21: { + T t = 2*y100 - 43; + return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 22: { + T t = 2*y100 - 45; + return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 23: { + T t = 2*y100 - 47; + return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t; + } + case 24: { + T t = 2*y100 - 49; + return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 25: { + T t = 2*y100 - 51; + return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 26: { + T t = 2*y100 - 53; + return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 27: { + T t = 2*y100 - 55; + return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 28: { + T t = 2*y100 - 57; + return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 29: { + T t = 2*y100 - 59; + return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 30: { + T t = 2*y100 - 61; + return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 31: { + T t = 2*y100 - 63; + return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 32: { + T t = 2*y100 - 65; + return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + 
(0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 33: { + T t = 2*y100 - 67; + return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 34: { + T t = 2*y100 - 69; + return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t; + } + case 35: { + T t = 2*y100 - 71; + return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 36: { + T t = 2*y100 - 73; + return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 37: { + T t = 2*y100 - 75; + return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 38: { + T t = 2*y100 - 77; + return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 39: { + T t = 2*y100 - 79; + return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t; + } + case 40: { + T t = 2*y100 - 81; + return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 41: { + T t = 2*y100 - 83; + return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 42: { + T t = 2*y100 - 85; + return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 43: { + T t = 2*y100 - 87; + return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 44: { + T t = 2*y100 - 89; + return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t; + } + case 45: { + T t = 2*y100 - 91; + return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + 
(0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 46: { + T t = 2*y100 - 93; + return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 47: { + T t = 2*y100 - 95; + return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 48: { + T t = 2*y100 - 97; + return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 49: { + T t = 2*y100 - 99; + return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 50: { + T t = 2*y100 - 101; + return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t; + } + case 51: { + T t = 2*y100 - 103; + return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 52: { + T t = 2*y100 - 105; + return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t; + } + case 53: { + T t = 2*y100 - 107; + return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t; + } + case 54: { + T t = 2*y100 - 109; + return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 55: { + T t = 2*y100 - 111; + return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 56: { + T t = 2*y100 - 113; + return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 57: { + T t = 2*y100 - 115; + return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 58: { + T t = 2*y100 - 117; + return 
0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 59: { + T t = 2*y100 - 119; + return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t; + } + case 60: { + T t = 2*y100 - 121; + return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t; + } + case 61: { + T t = 2*y100 - 123; + return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 62: { + T t = 2*y100 - 125; + return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 63: { + T t = 2*y100 - 127; + return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 64: { + T t = 2*y100 - 129; + return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 65: { + T t = 2*y100 - 131; + return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t; + } + case 66: { + T t = 2*y100 - 133; + return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 67: { + T t = 2*y100 - 135; + return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t; + } + case 68: { + T t = 2*y100 - 137; + return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t; + } + case 69: { + T t = 2*y100 - 139; + return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 70: { + T t = 2*y100 - 141; + return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * 
t) * t) * t; + } + case 71: { + T t = 2*y100 - 143; + return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 72: { + T t = 2*y100 - 145; + return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 73: { + T t = 2*y100 - 147; + return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t; + } + case 74: { + T t = 2*y100 - 149; + return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 75: { + T t = 2*y100 - 151; + return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t; + } + case 76: { + T t = 2*y100 - 153; + return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 77: { + T t = 2*y100 - 155; + return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 78: { + T t = 2*y100 - 157; + return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 79: { + T t = 2*y100 - 159; + return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 80: { + T t = 2*y100 - 161; + return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 81: { + T t = 2*y100 - 163; + return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 82: { + T t = 2*y100 - 165; + return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 83: { + T t = 2*y100 - 167; + return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + 
(0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 84: { + T t = 2*y100 - 169; + return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 85: { + T t = 2*y100 - 171; + return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 86: { + T t = 2*y100 - 173; + return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t; + } + case 87: { + T t = 2*y100 - 175; + return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 88: { + T t = 2*y100 - 177; + return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t; + } + case 89: { + T t = 2*y100 - 179; + return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 90: { + T t = 2*y100 - 181; + return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 91: { + T t = 2*y100 - 183; + return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 92: { + T t = 2*y100 - 185; + return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t; + } + case 93: { + T t = 2*y100 - 187; + return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 94: { + T t = 2*y100 - 189; + return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 95: { + T t = 2*y100 - 191; + return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 96: { + T t = 2*y100 - 193; + return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 
+ (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 97: { + T t = 2*y100 - 195; + return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t; + } + case 98: { + T t = 2*y100 - 197; + return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 99: { + T t = 2*y100 - 199; + return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t; + } + } + + // we only get here if y = 1, i.e. |x| < 4*eps, in which case + // erfcx is within 1e-15 of 1.. + return 1.; + } + + template + T erfcx(T x) { + // Short-circuits on NaN (returning NaN) + if (x != x) { + return x; + } + + if (x >= 0) { + if (x > T{50}) { // continued-fraction expansion is faster + const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi) + + if (x > T{5e7}) { // 1-term expansion, important to avoid overflow + return ispi / x; + } + + /* 5-term expansion (rely on compiler for CSE), simplified from: + ispi / (x+0.5/(x+1/(x+1.5/(x+2/x)))) */ + return ispi * ((x*x) * (x*x+T{4.5}) + T{2}) / (x * ((x*x) * (x*x+T{5}) + T{3.75})); + } + + // x >= 0 x <= 50 + return erfcx_y100(T{400} / (T{4} + x)); + } + + // x < 0 + if (x < T{-26.7}) { + return POS_INFINITY; + } else if (x < T{-6.1}) { + return T{2} * exp(x * x); + } + + // x < 0 and x >= -6.1 + return T{2} * exp(x * x) - erfcx_y100(T{400} / (T{4} - x)); + } +); // erfcx_string + +const auto airy_ai_string = jiterator_stringify( + template + T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + +7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + 
+8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (isinf(x)) { + return NAN; + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd = 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * sqrt(-x) / T(3.0) + T(0.25) * T(3.14159265358979323846); + + return T(5.64189583547756286948e-01) / sqrt(sqrt(-x)) * (sin(t) * (T(1.0) + z * z * afn / afd) - cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * sqrt(sqrt(x)) * exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > T(1.11022302462515654042e-16)) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; + } // T airy_ai(T x) +); // airy_ai_string + +const auto bessel_j0_string = jiterator_stringify( + template + T bessel_j0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T RP[] = { + -4.79443220978201773821e+09, + +1.95617491946556577543e+12, + -2.49248344360967716204e+14, + +9.70862251047306323952e+15, + }; + + static const T RQ[] = { + +4.99563147152651017219e+02, + 
+1.73785401676374683123e+05, + +4.84409658339962045305e+07, + +1.11855537045356834862e+10, + +2.11277520115489217587e+12, + +3.10518229857422583814e+14, + +3.18121955943204943306e+16, + +1.71086294081043136091e+18, + }; + + if (x < T(0)) { + x = -x; + } + + if (x <= T(5.0)) { + if (x < T(0.00001)) { + return T(1.0) - x * x / T(4.0); + } + + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq; + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_j0_forward(T x) +); // bessel_j0_string + +const auto bessel_y0_string = bessel_j0_string + jiterator_stringify( + template + T bessel_y0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T YP[] = { + +1.55924367855235737965e+04, + -1.46639295903971606143e+07, + +5.43526477051876500413e+09, + -9.82136065717911466409e+11, + +8.75906394395366999549e+13, + -3.46628303384729719441e+15, + +4.42733268572569800351e+16, + -1.84950800436986690637e+16, + }; + + static const T YQ[] = { + +1.04128353664259848412e+03, + +6.26107330137134956842e+05, + +2.68919633393814121987e+08, + +8.64002487103935000337e+10, + +2.02979612750105546709e+13, + +3.17157752842975028269e+15, + +2.50596256172653059228e+17, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return NEG_INFINITY; + } + + if (x < T(0.0)) { + NAN; + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return yp / yq + (T(0.636619772367581343075535053490057448) * log(x) * bessel_j0_forward(x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) 
/ (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_y0_forward(T x) +); // bessel_y0_string + +const auto bessel_j1_string = jiterator_stringify( + template + T bessel_j1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T RP[] = { + -8.99971225705559398224e+08, + +4.52228297998194034323e+11, + -7.27494245221818276015e+13, + +3.68295732863852883286e+15, + }; + + static const T RQ[] = { + +6.20836478118054335476e+02, + +2.56987256757748830383e+05, + +8.35146791431949253037e+07, + +2.21511595479792499675e+10, + +4.74914122079991414898e+12, + +7.84369607876235854894e+14, + +8.95222336184627338078e+16, + +5.32278620332680085395e+18, + }; + + if (x < T(0.0)) { + return -bessel_j1_forward(-x); + } + + if (x <= T(5.0)) { + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_j1_forward(T x) +); // bessel_j1_string + +const auto bessel_y1_string = bessel_j1_string + jiterator_stringify( + template + T bessel_y1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + 
+7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T YP[] = { + +1.26320474790178026440e+09, + -6.47355876379160291031e+11, + +1.14509511541823727583e+14, + -8.12770255501325109621e+15, + +2.02439475713594898196e+17, + -7.78877196265950026825e+17, + }; + + static const T YQ[] = { + +5.94301592346128195359e+02, + +2.35564092943068577943e+05, + +7.34811944459721705660e+07, + +1.87601316108706159478e+10, + +3.88231277496238566008e+12, + +6.20557727146953693363e+14, + +6.87141087355300489866e+16, + +3.97270608116560655612e+18, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return NEG_INFINITY; + } + + if (x <= T(0.0)) { + return NAN; + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 5; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * log(x) - T(1.0) / x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_y1_forward(T x) +); // bessel_y1_string + +const auto chebyshev_polynomial_t_string = jiterator_stringify( + template + T chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (abs(x) < T(1.0))) { + return cos(n * acos(x)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_t_forward(T x, int64_t n) + + template + T chebyshev_polynomial_t_forward(T x, T n) { + return chebyshev_polynomial_t_forward(x, static_cast(n)); + } // chebyshev_polynomial_t_forward(T x, T n) +); // chebyshev_polynomial_t_string + +const auto chebyshev_polynomial_u_string = jiterator_stringify( + 
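+ // Chebyshev polynomial of the second kind, U_n(x). For n > 8 with |x| < 1 the closed form sin((n + 1) * acos(x)) / sin(acos(x)) is used; otherwise the three-term recurrence U_{k+1}(x) = 2x * U_k(x) - U_{k-1}(x) is iterated from U_0 = 1, U_1 = 2x.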
template + T chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (sin(acos(x)) != T(0.0)) { + return sin((n + 1) * acos(x)) / sin(acos(x)); + } + + return (n + 1) * cos((n + 1) * acos(x)) / x; + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_u_forward(T x, int64_t n) + + template + T chebyshev_polynomial_u_forward(T x, T n) { + return chebyshev_polynomial_u_forward(x, static_cast(n)); + } // chebyshev_polynomial_u_forward(T x, T n) +); // chebyshev_polynomial_u_string + +const auto chebyshev_polynomial_v_string = jiterator_stringify( + template + T chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0)) { + return T(1.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (sin(acos(x) / T(2.0)) != T(1.0)) { + return cos((n + T(0.5)) * acos(x)) / cos(acos(x) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_v_forward(T x, int64_t n) + + template + T chebyshev_polynomial_v_forward(T x, T n) { + return chebyshev_polynomial_v_forward(x, static_cast(n)); + } // chebyshev_polynomial_v_forward(T x, T n) +); // chebyshev_polynomial_v_string + +const auto chebyshev_polynomial_w_string = jiterator_stringify( + template + T chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (cos(acos(x) / T(2.0)) != T(1.0)) { + return sin((n + T(0.5)) * acos(x)) / sin(acos(x) / T(2.0)); + } + + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x + T(1.0); + } + + T p = T(1.0); + T q = x + x + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_w_forward(T x, int64_t n) + + template + T chebyshev_polynomial_w_forward(T x, T n) { + return chebyshev_polynomial_w_forward(x, static_cast(n)); + } // chebyshev_polynomial_w_forward(T x, T n) +); // chebyshev_polynomial_w_string + +const auto hermite_polynomial_h_string = jiterator_stringify( + template + unsigned short getHermitianLimit() { + if (sizeof(T) <= sizeof(float)) { + return 128; + } else if (sizeof(T) <= sizeof(double)) { + return 512; + } else { + return 1024; + } + } + + template + T hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + T p = T(1.0); + T q = x + x; + T r = T(0.0); + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = 
q; + q = r; + } + + return r; + } // hermite_polynomial_h_forward(T x, int64_t n) + + template + T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, static_cast(n)); + } // hermite_polynomial_h_forward(T x, T n) +); // hermite_polynomial_h_string + +const auto hermite_polynomial_he_string = jiterator_stringify( + template + unsigned short getHermitianLimit() { + if (sizeof(T) <= sizeof(float)) { + return 128; + } else if (sizeof(T) <= sizeof(double)) { + return 512; + } else { + return 1024; + } + } + + template + T hermite_polynomial_he_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = x * q - k * p; + p = q; + q = r; + } + + return r; + } // hermite_polynomial_he_forward(T x, int64_t n) + + template + T hermite_polynomial_he_forward(T x, T n) { + return hermite_polynomial_he_forward(x, static_cast(n)); + } // hermite_polynomial_he_forward(T x, T n) +); // hermite_polynomial_he_string + +const auto laguerre_polynomial_l_string = jiterator_stringify( + template + T laguerre_polynomial_l_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(0.0)) { + return T(1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return T(1.0) - x; + } + + T p = T(1.0); + T q = T(1.0) - x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; + } // laguerre_polynomial_l_forward(T x, int64_t n) + + template + T laguerre_polynomial_l_forward(T x, T n) { + return laguerre_polynomial_l_forward(x, static_cast(n)); + } // laguerre_polynomial_l_forward(T x, T n) +); // laguerre_polynomial_l_string + +const auto legendre_polynomial_p_string = jiterator_stringify( + template + T legendre_polynomial_p_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = ((k + k + 1) * x * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; + } // legendre_polynomial_p_forward(T x, int64_t n) + + template + T legendre_polynomial_p_forward(T x, T n) { + return legendre_polynomial_p_forward(x, static_cast(n)); + } // legendre_polynomial_p_forward(T x, T n) +); // legendre_polynomial_p_string + +const auto modified_bessel_i0_string = jiterator_stringify( + template + T modified_bessel_i0_forward(T x) { + static const T A[] = { + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + 
+4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + }; + + static const T B[] = { + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + }; + + T p; + T q = 0.0; + + if (abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 30; index++) { + p = q; + q = a; + a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + return exp(abs(x)) * (T(0.5) * (a - p)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index]; + } + + return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)); + } // modified_bessel_i0_forward(T x) +); // modified_bessel_i0_string + +const auto modified_bessel_i1_string = jiterator_stringify( + template + T modified_bessel_i1_forward(T x) { + static const T A[] = { + +2.77791411276104639959e-18, + -2.11142121435816608115e-17, + +1.55363195773620046921e-16, + -1.10559694773538630805e-15, + +7.60068429473540693410e-15, + -5.04218550472791168711e-14, + +3.22379336594557470981e-13, + -1.98397439776494371520e-12, + +1.17361862988909016308e-11, + -6.66348972350202774223e-11, + +3.62559028155211703701e-10, + -1.88724975172282928790e-09, + +9.38153738649577178388e-09, + -4.44505912879632808065e-08, + +2.00329475355213526229e-07, + -8.56872026469545474066e-07, + +3.47025130813767847674e-06, + -1.32731636560394358279e-05, + +4.78156510755005422638e-05, + -1.61760815825896745588e-04, + +5.12285956168575772895e-04, + -1.51357245063125314899e-03, + +4.15642294431288815669e-03, + -1.05640848946261981558e-02, + +2.47264490306265168283e-02, + -5.29459812080949914269e-02, + +1.02643658689847095384e-01, + -1.76416518357834055153e-01, + +2.52587186443633654823e-01, + }; + + static const T B[] = { + +7.51729631084210481353e-18, + +4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + +2.96262899764595013876e-16, + +3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + +1.04202769841288027642e-14, + +4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + +2.03562854414708950722e-12, + +1.41258074366137813316e-11, + +3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + +7.78576235018280120474e-01, + }; + + T p; + T q = 0.0; + + if (abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 29; index++) { + p = q; + q = a; + a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + if (x < T(0.0)) { + return 
-(T(0.5) * (a - p) * abs(x) * exp(abs(x))); + } + + return T(0.5) * (a - p) * abs(x) * exp(abs(x)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index]; + } + + if (x < T(0.0)) { + return -(exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x))); + } + + return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)); + } // modified_bessel_i1_forward(T x) +); // modified_bessel_i1_string + +const auto modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify( + template + T modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return T(0.5) * (a - p) - log(0.5 * x) * modified_bessel_i0_forward(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return exp(-x) * (T(0.5) * (b - p)) / sqrt(x); + } // modified_bessel_k0_forward(T x) +); // modified_bessel_k0_string + +const auto scaled_modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify( + template + T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + 
-1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - log(T(0.5) * x) * modified_bessel_i0_forward(x)) * exp(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / sqrt(x); + } // T scaled_modified_bessel_k0_forward(T x) +); // scaled_modified_bessel_k0_string + +const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( + template + T modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x; + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return exp(-x) * (T(0.5) * (b - p)) / sqrt(x); + } // modified_bessel_k1_forward(T x) +); // modified_bessel_k1_string + +const auto scaled_modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( + template + T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + 
-1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * exp(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return (T(0.5) * (b - p) / sqrt(x)); + } // T scaled_modified_bessel_k1_forward(T x) +); // scaled_modified_bessel_k1_string + +const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + return cos(n * acos(x + x - T(1.0))); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_t_forward(T x, T n) { + return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_t_forward(T x, T n) +); // shifted_chebyshev_polynomial_t_string + +const auto shifted_chebyshev_polynomial_u_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + if (sin(acos(x + x - T(1.0))) != T(0.0)) { + return sin((n + 1) * acos(x + x - T(1.0))) / sin(acos(x + x - T(1.0))); + } + + return (n + 1) * cos((n + 1) * acos(x + x - T(1.0))) / (x + x - T(1.0)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_u_forward(T x, T n) { + return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_u_forward(T x, T n) +); // shifted_chebyshev_polynomial_u_string + +const auto shifted_chebyshev_polynomial_v_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); 
+ } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + if (sin(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return cos(((n) + T(0.5)) * acos(x + x - T(1.0))) / cos(acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_v_forward(T x, T n) { + return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_v_forward(T x, T n) +); // shifted_chebyshev_polynomial_v_string + +const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 4) && (abs(x + x - T(1.0)) < T(1.0))) { + if (cos(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return sin((n + T(0.5)) * acos(x + x - T(1.0))) / sin(acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_w_forward(T x, T n) { + return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_w_forward(T x, T n) +); // shifted_chebyshev_polynomial_w_string + +const auto spherical_bessel_j0_string = jiterator_stringify( + template + T spherical_bessel_j0_forward(T x) { + if (isinf(x)) { + return T(0.0); + } + + if (abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return sin(x) / x; + } // T spherical_bessel_j0_forward(T x) +); // spherical_bessel_j0_string + +#else // !AT_USE_JITERATOR() -- kernels must be precompiled + +template +static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) { + scalar_t a = ::abs(a_in); + scalar_t b = ::abs(b_in); + while (a != 0) { + scalar_t c = a; + a = b % a; + b = c; + } + return b; +} + +/* + * For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h". 
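+ * + * The digamma evaluation below uses the reflection formula for negative arguments, the recurrence psi(x + 1) = psi(x) + 1/x to push x up to at least 10, and an asymptotic series in 1/x^2 for the remaining tail.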
+ */ +template +static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma + using accscalar_t = at::acc_type; + static const double PI_f64 = 3.14159265358979323846; + const accscalar_t PSI_10 = 2.25175258906672110764; + const accscalar_t A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + accscalar_t x = static_cast(in); + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(static_cast(INFINITY), -x); + } + + bool x_is_integer = x == ::trunc(x); + accscalar_t result = 0; + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return static_cast(NAN); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = ::modf(static_cast(x), &q); + result = static_cast(- PI_f64 / ::tan(PI_f64 * r)); + x = 1 - x; + } + + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return static_cast(result + PSI_10); + } + + accscalar_t y = 0; + if (x < 1.0e17) { + accscalar_t z = 1 / (x * x); + + accscalar_t polevl_result = 0; + for (int i = 0; i <= 6; i++) { + polevl_result = polevl_result * z + A[i]; + } + y = z * polevl_result; + } + + return static_cast(::log(x) - (static_cast(0.5) / x) - y + result); +} + +template +static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) { + using accscalar_t = at::acc_type; + const accscalar_t PI = 3.14159265358979323846; + accscalar_t x = static_cast(in); + accscalar_t sign = +1; + accscalar_t result = 0; + if (x < 0.5f) { + sign = -1; + accscalar_t sin_pi_x = ::sin(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const accscalar_t one = static_cast(1); + const accscalar_t ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (one/6 - ixx * (one/30 - ixx * (one/42)))) / x; + return static_cast(sign * result); +} + +/* + * For licensing information and documentation, please refer to the cpu implementation located in "ATen/native/Math.h". + */ +template +static inline C10_HOST_DEVICE scalar_t +chbevl(scalar_t _x, const scalar_t array[], size_t len) { + static_assert(!std::is_same() && !std::is_same(), "don't instantiate with low precision type"); + + scalar_t b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (size_t i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = _x * b1 - b2 + array[i]; + } + + return (0.5 * (b0 - b2)); +} + +/* + * For licensing information and documentation, please refer to the cpu implementation located in "ATen/native/Math.h". + */ +template +C10_HOST_DEVICE inline std::tuple chebyshev_coefficients_i0e_A() { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. 
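+ * + * The series is evaluated with chbevl() on the remapped argument y = x/2 - 2, which runs over [-2, 2] as x runs over [0, 8] (see calc_i0 below).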
+ */ + static const T coefficients[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + return std::make_tuple(coefficients, 30); +} + +template +C10_HOST_DEVICE inline std::tuple chebyshev_coefficients_i0e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). + */ + static const T coefficients[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return std::make_tuple(coefficients, 25); +} + +template +static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) { + static_assert(!std::is_same() && !std::is_same(), "don't instantiate with low precision type"); + // Upcast input for numerical accuracy purposes + // Needed for accurate results if input is bfloat16 or float16 + scalar_t x = ::abs(_x); + + if (x <= scalar_t{8.0}) { + auto [A, len] = chebyshev_coefficients_i0e_A(); + scalar_t y = (x / scalar_t{2.0}) - scalar_t{2.0}; + return (::exp(x) * chbevl(y, A, len)); + } + + auto [B, len] = chebyshev_coefficients_i0e_B(); + return (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x)); +} + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. 
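+ * + * calc_i1 and calc_i1e below multiply the resulting Chebyshev sum by x and restore the sign of the original argument afterwards, since I1 is an odd function.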
+ */ + static const T coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + + return std::make_tuple(coefficients, 29); +} + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. + */ + static const T coeff[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + return std::make_tuple(coeff, 17); +}; + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + return std::make_tuple(coefficients, 25); +} + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). 
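+ * + * The single-precision specialization that follows keeps only the trailing 7 coefficients of the double-precision table above.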
+ */ + static const T coeff[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + return std::make_tuple(coeff, 7); +}; + +template +static inline C10_HOST_DEVICE scalar_t calc_i1(scalar_t _x) { + const auto x = ::abs(_x); + if (x <= scalar_t{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + scalar_t y = x / scalar_t{2.0} - scalar_t{2.0}; + const scalar_t out = ::exp(x) * x * chbevl(y, A, len); + return (_x < scalar_t{0.0}) ? -out : out; + } + + auto [B, len] = chebyshev_coefficients_i1e_B(); + const scalar_t out = (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len)) / ::sqrt(x); + return (_x < scalar_t{0.0}) ? -out : out; +} + +template +static inline C10_HOST_DEVICE scalar_t calc_i1e(scalar_t _x) { + const auto x = ::abs(_x); + if (x <= scalar_t{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + const scalar_t y = x / scalar_t{2.0} - scalar_t{2.0}; + const scalar_t out = chbevl(y, A, len) * x; + return (_x < scalar_t{0.0}) ? -out : out; + } + + auto [B, len] = chebyshev_coefficients_i1e_B(); + const scalar_t out = chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x); + return (_x < scalar_t{0.0}) ? -out : out; +} + +#endif // AT_USE_JITERATOR() (this closes the "else" branch of a if/else preprocessor directive) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c0575223ded558ba4e2abb5d264b4705682b5dac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh @@ -0,0 +1,682 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// References: +// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/ + +namespace at::native::memory { + +namespace detail { + +// What does the `static_unroll` do? +// +// We want to do something like: +// +// using args_t = typename traits::ArgsTuple; +// args_t args; +// #pragma unroll +// for (int i = 0; i < traits::arity; i++) { +// std::get(args) = .... +// } +// +// but unfortunately the above code does not work because +// the template argument has to be a compile time constant +// so `static_unroll` is created to simulate `#pragma unroll` +// using template metaprogramming. + +template typename func, int end, int current=0> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args&&... args) { + func::apply(std::forward(args)...); + static_unroll::with_args(args...); + } +}; + +template typename func, int end> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args... 
/*args*/) {} +}; + +// helper structs to be used with static_unroll to load arguments +// one by one + +template +struct vectorized_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, int idx, int block_work_size) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + auto ptr = reinterpret_cast(self.data[arg_index + 1]) + block_work_size * idx; + auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get(args[thread_unroll_idx]); }; + self.load_single_arg(args_accessor, ptr); + } +}; + +#ifdef USE_ROCM +// Templated version of vectorized load helper. +// It can be used on heterogeneous input tensor element types. +template +struct vectorized_templated_load_helper { + template + static __device__ void apply(policy_t& self, args_t* args, int idx) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + + // Delay pointer arithmetic to the policy loader where we know the actual + // type of the current argument. + char* ptr = (self.data[arg_index + 1]); + auto args_accessor = [&args] __device__(int thread_unroll_idx) -> arg_t& { + return std::get(args[thread_unroll_idx]); + }; + self.template load_single_arg(args_accessor, ptr, idx); + } +}; +#endif + +template +struct unroll_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + std::get(args[j]) = loader.template load(self.data[arg_index + num_outputs], offset[arg_index], arg_index); + } +}; + +template +struct multi_outputs_store_helper { + template + C10_HOST_DEVICE static void apply( + const data_t& data, + const offsets_t& offsets, + thrust::tuple ret) { + using T = typename thrust::tuple_element>::type; + T *to = reinterpret_cast(data[current]) + offsets[current]; + *to = thrust::get(ret); + } +}; + +} // namespace detail + +struct LoadWithoutCast { + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + return c10::load(reinterpret_cast(base_ptr) + offset); + } +}; + +template +struct LoadWithCast { + using array_t = std::array(N, 1)>; + using size_array_t = std::array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + LoadWithCast(const TensorIteratorBase& iter) { + CUDA_KERNEL_ASSERT(iter.ninputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + this->dtypes[i] = iter.dtype(i + iter.noutputs()); + element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs())); + } + } + + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + void *ptr = base_ptr + element_sizes[arg] * offset; + return c10::fetch_and_cast(dtypes[arg], ptr); + } +}; + +struct StoreWithoutCast { + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + *(reinterpret_cast(base_ptr) + offset) = value; + } +}; + +template +struct StoreWithCast { + using array_t = std::array(N, 1)>; + using size_array_t = std::array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + StoreWithCast(const TensorIteratorBase& iter) { + CUDA_KERNEL_ASSERT(iter.noutputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + 
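+ // Cache each output's dtype and element size so that store() can compute the byte offset and cast the value on the fly.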
this->dtypes[i] = iter.dtype(i); + element_sizes[i] = c10::elementSize(iter.dtype(i)); + } + } + + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + void *ptr = base_ptr + element_sizes[arg] * offset; + c10::cast_and_store(dtypes[arg], ptr, value); + } +}; + +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; +}; + +template +__device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { + using vec_t = aligned_vector; + auto *from = reinterpret_cast(base_ptr); +#if defined(USE_ROCM) && defined(__gfx942__) + using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int; + if constexpr (sizeof(vec_t) == sizeof(int)) { + union { + vec_t v; + int i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } + else if constexpr (sizeof(vec_t) == sizeof(long)) { + union { + vec_t v; + long i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } + else if constexpr (sizeof(vec_t) == sizeof(longx2)) { + union { + vec_t v; + longx2 i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } +#endif + return from[offset]; +} + +template +__device__ aligned_vector load_vector(const bool *base_ptr, uint32_t offset) { + // See NOTE [Loading boolean values] + auto tmp = load_vector(reinterpret_cast(base_ptr), offset); + aligned_vector ret; + for (int i = 0; i < vec_size; ++i) { + ret.val[i] = bool(tmp.val[i]); + } + return ret; +} + +namespace policies { + +template < + int num_threads, + typename data_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + int elems_per_thread, + int num_outputs = 1> +struct unroll_base { + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + loader_t loader; + storer_t storer; + static constexpr int tws = elems_per_thread; + static constexpr int block_work_size = elems_per_thread * num_threads; + + __device__ unroll_base( + data_t data, + int remaining, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) + : data(data), + remaining(remaining), + input_offset_calculator(ic), + output_offset_calculator(oc), + loader(l), + storer(s) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem * num_threads) < remaining); + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size_v; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) { + if (thread_idx < remaining) { + int linear_idx = thread_idx + block_work_size * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args( + *this, args, offset, loader, i, num_outputs); + thread_idx += num_threads; + } + } + } + + template + __device__ inline void store(scalar_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) { + if (thread_idx < remaining) { + int linear_idx = thread_idx + block_work_size * idx; + int offset = output_offset_calculator.get(linear_idx)[0]; + storer.store(from[i], data[0], offset); + thread_idx += num_threads; + } + } + } +}; + +// Utility type for all users of unroll that extract the num_threads value from +// the caller scope. 
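+// (For illustration: an elementwise kernel with one output and one input,
+//  loading and storing without dtype casts and processing thread_work_size()
+//  elements per thread, would instantiate this roughly as
+//      unroll<data_array_t, input_offset_calc_t, output_offset_calc_t,
+//             LoadWithoutCast, StoreWithoutCast, thread_work_size()>
+//  where the *_t placeholders stand for whatever pointer array and offset
+//  calculators the caller already has; the alias only fills in num_threads()
+//  from the caller's scope and otherwise forwards to unroll_base.)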
+template < + typename data_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + int elems_per_thread, + int num_outputs = 1> +using unroll = unroll_base< + num_threads(), + data_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + elems_per_thread, + num_outputs>; + +template // vec_size: number of scalars, can be 1, 2, or 4. +struct vectorized { + + static_assert(elems_per_thread % vec_size == 0, "The workload per thread must be a multiple of vec_size"); + static constexpr int loop_size = elems_per_thread / vec_size; + static constexpr int tws = elems_per_thread; + + data_t data; + + __device__ vectorized(data_t data) : data(data) {} + + __device__ inline constexpr bool check_inbounds(int thread_work_elem) { + return true; + } + + template + __device__ inline void load_single_arg(accessor_t to, scalar_t *from) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + auto v = load_vector(from, index); + #pragma unroll + for (int j = 0; j < vec_size; j++) { + to(vec_size * i + j) = v.val[j]; + } + } + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size_v; + detail::static_unroll::with_args(*this, args, idx, elems_per_thread * num_threads()); + } + + template + __device__ inline void store(scalar_t *from, int idx) { + using vec_t = aligned_vector; + scalar_t *to = reinterpret_cast(data[0]) + elems_per_thread * num_threads() * idx; + vec_t *to_ = reinterpret_cast(to); + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + vec_t v; + for (int j = 0; j < vec_size; j++) { + v.val[j] = from[vec_size * i + j]; + } + to_[index] = v; + } + } +}; + +#ifdef USE_ROCM +// This is similar to vectorized policy above, but this one supports +// heterogenous input tensor types as templated parameters. +// Its use should be limited to frequently used heterogeneous data types +// as each instantiation will generate a separate kernel, leading to code +// bloating if applied to all combinations supported in PyTorch. Assumption: all +// tensors are contiguous, that is: stride == sizeof(type) for all tensors. +template < + int vec_size, + typename data_t, + int elems_per_thread, + int num_threads, + typename CastToT, + typename... CastFromTs> // vec_size: number of scalars, can be 1, 2, or 4. +struct vectorized_templated { + static_assert( + elems_per_thread % vec_size == 0, + "The workload per thread must be a multiple of vec_size"); + static constexpr int loop_size = elems_per_thread / vec_size; + static constexpr int tws = elems_per_thread; + static constexpr int block_work_size = elems_per_thread * num_threads; + data_t data; + + __device__ vectorized_templated(data_t data) : data(data) {} + + __device__ inline constexpr bool check_inbounds(int thread_work_elem) { + return true; + } + + template + __device__ inline void load_single_arg(accessor_t to, char* ptr, int idx) { + // extract the arg_index-th input tensor element type from the + // variadic template argument. + using CastFromT = + std::tuple_element_t>; + // Delayed pointer arithmetic from the caller: this is the place + // where we know the type of the argument. 
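+    // Illustration (hypothetical case): if the input tensor at arg_index == 1
+    // holds bfloat16 while the functor computes in float, CastFromT resolves to
+    // bfloat16 here, so the pointer below is advanced in bfloat16 units and each
+    // loaded lane is converted before being handed to the functor.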
+ CastFromT* block_ptr = + reinterpret_cast(ptr) + block_work_size * idx; + int thread_idx = threadIdx.x; +#pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads; + auto v = load_vector(block_ptr, index); +#pragma unroll + for (int j = 0; j < vec_size; j++) { + to(vec_size * i + j) = c10::convert(v.val[j]); + } + } + } + + template + __device__ inline void load(args_t* args, int idx) { + constexpr int arity = std::tuple_size::value; + detail::static_unroll:: + with_args(*this, args, idx); + } + + // Assume for now that from (temporary array per thread) is of the same + // type as to (destination tensor), which is the case for + // float(float,bfloat16) and functor add on float(float,float). + template + __device__ inline void store(scalar_t* from, int idx) { + using vec_t = aligned_vector; + CastToT* to = reinterpret_cast(data[0]) + block_work_size * idx; + vec_t* to_ = reinterpret_cast(to); + int thread_idx = threadIdx.x; +#pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads; + vec_t v; + for (int j = 0; j < vec_size; j++) { + v.val[j] = from[vec_size * i + j]; + } + to_[index] = v; + } + } +}; +#endif + +template +struct multi_outputs_unroll { + //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct + //we don't use inheritance because of compiler bug in cuda 10.2+ + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + LoadWithoutCast loader; + StoreWithoutCast storer; + static constexpr int tws = thread_work_size(); + + __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc): + data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining); + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size_v; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args(*this, args, offset, loader, i, num_outputs); + thread_idx += num_threads(); + } + } + + + template + __device__ inline void store(return_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= this->remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offsets = this->output_offset_calculator.get(linear_idx); + memory::detail::static_unroll::with_args(this->data, offsets, from[i]); + thread_idx += num_threads(); + } + } +}; + +} // namespace policies + +// This is only used in host, but we will wrap this into some templates +// which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE +// in order to compile +template +inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec2_alignment = std::alignment_of_v>; + constexpr int vec4_alignment = std::alignment_of_v>; + constexpr int vec8_alignment = std::alignment_of_v>; +#ifdef USE_ROCM + constexpr int vec16_alignment = std::alignment_of_v>; + constexpr int type_size = sizeof(scalar_t); + if 
(type_size == 1 && (address % vec16_alignment == 0)) { + return 16; + } else if (type_size <= 2 && (address % vec8_alignment == 0)) { + return 8; + } else +#else + if (address % vec8_alignment == 0) { + return 8; + } else +#endif + if (address % vec4_alignment == 0) { + return 4; + } else if (address % vec2_alignment == 0) { + return 2; + } + return 1; +} + +template +inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) { + return can_vectorize_up_to(static_cast(pointer)); +} + +template +struct can_vectorize_up_to_helper { + template + static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits /*_*/) { + using arg_t = typename traits::template arg::type; + // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + result = std::min(result, can_vectorize_up_to(pointers[i + 1])); + } +}; + +template +inline int can_vectorize_up_to(array_t pointers) { + using traits = function_traits; + using return_t = typename traits::result_type; + constexpr int arity = traits::arity; + int result = can_vectorize_up_to(pointers[0]); + // We need to get the type for each argument of `func_t`, this can only + // be done at compile time. + detail::static_unroll::with_args(result, pointers, traits()); + return result; +} + + + +template +__inline__ size_t get_alignment(T ptr_or_size) { + auto val = reinterpret_cast(ptr_or_size); + if (val % 16 == 0) { + return 16; + } else if (val % 8 == 0) { + return 8; + } else if (val % 4 == 0) { + return 4; + } else if (val % 2 == 0) { + return 2; + } else { + return 1; + } +} + +template <> +__inline__ size_t get_alignment(size_t size) { + return get_alignment(reinterpret_cast(size)); +} + +template +inline constexpr bool dependent_bool_value = Value; + +template +inline constexpr bool dependent_false = dependent_bool_value; + +template +union Vec; + +template <> +union Vec<4> { + uint16_t u16[2]; + uint32_t u32, as_scalar; + float f32; +}; + +template <> +union Vec<8> { + uint16_t u16[4]; + uint32_t u32[2]; + uint64_t u64, as_scalar; + float f32[2]; +}; + +template <> +union alignas(16) Vec<16> { + uint16_t u16[8]; + uint32_t u32[4]; + uint64_t u64[2]; + uint4 u128, as_scalar; + float f32[4]; +}; + +template +__device__ __inline__ Vec ld_vec(const T* addr) { + Vec vec; + if constexpr (Alignment == 16) { +#if defined(USE_ROCM) + vec.u128 = *reinterpret_cast(addr); + } else if constexpr (Alignment == 8) { + vec.u64 = *reinterpret_cast(addr); + } else if constexpr (Alignment == 4) { + vec.u32 = *reinterpret_cast(addr); +#else + asm("ld.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(vec.u32[0]), "=r"(vec.u32[1]), "=r"(vec.u32[2]), "=r"(vec.u32[3]) + : "l"(addr) + : "memory"); + } else if constexpr (Alignment == 8) { + asm("ld.global.v2.u32 {%0,%1}, [%2];" + : "=r"(vec.u32[0]), "=r"(vec.u32[1]) + : "l"(addr) + : "memory"); + } else if constexpr (Alignment == 4) { + asm("ld.global.u32 %0, [%1];" : "=r"(vec.u32) : "l"(addr) : "memory"); +#endif + } else { + static_assert(dependent_false); + } + return vec; +} + +template +__device__ __inline__ void st_vec(T* addr, const Vec& vec) { + if constexpr (Alignment == 16) { +#if defined(USE_ROCM) + reinterpret_cast(addr)[0] = vec.u64[0]; + reinterpret_cast(addr)[1] = vec.u64[1]; + } else if constexpr (Alignment == 8) { + *reinterpret_cast(addr) = vec.u64; + } else if constexpr (Alignment == 4) { + *reinterpret_cast(addr) = vec.u32; +#else + asm("st.global.v4.u32 [%0], {%1,%2,%3,%4};" + : + : "l"(addr), + "r"(vec.u32[0]), + "r"(vec.u32[1]), + 
"r"(vec.u32[2]), + "r"(vec.u32[3]) + : "memory"); + } else if constexpr (Alignment == 8) { + asm("st.global.v2.u32 [%0], {%1,%2};" + : + : "l"(addr), "r"(vec.u32[0]), "r"(vec.u32[1]) + : "memory"); + } else if constexpr (Alignment == 4) { + asm("st.global.u32 [%0], %1;" : : "l"(addr), "r"(vec.u32) : "memory"); +#endif + } else { + static_assert(dependent_false); + } +} + + + +} // namespace at::native::memory diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MiscUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MiscUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7ab20b6418c7b758348d38a11874eea35760c753 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MiscUtils.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include +#include +#include + + +namespace at::native { + +static inline int cuda_int_cast(int64_t value, const char* varname) { + auto result = static_cast(value); + TORCH_CHECK(static_cast(result) == value, + "cuda_int_cast: The value of ", varname, "(", (long long)value, + ") is too large to fit into a int (", sizeof(int), " bytes)"); + return result; +} + +// Creates an array of size elements of type T, backed by pinned memory +// wrapped in a Storage +template +static inline Storage pin_memory(int64_t size) { + auto* allocator = cuda::getPinnedMemoryAllocator(); + int64_t adjusted_size = size * sizeof(T); + return Storage( + Storage::use_byte_size_t(), + adjusted_size, + allocator, + /*resizable=*/false); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3271c7eab444bde5c2ce0279d8eabbd58ba695b6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh @@ -0,0 +1,382 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace at::native { + +namespace { + +static constexpr int64_t kILP = 4; +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kBlockSize = 512; + +// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` +// TensorListMetadata has to be < 4KB - the limit for kernel launch argument +static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; +static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 72, + 60}; + +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (kILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store( + T* dst, + T* src, + int64_t dst_offset, + int64_t src_offset) { + using LT = at::native::memory::aligned_vector; + ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; +} + +template +struct TensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +struct TensorListScalarListMetadata { + const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; + scalar_vals_t 
scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; +}; + +// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of +// 4kb with `c10::complex` +template <> +struct TensorListScalarListMetadata, 1> { + const void* addresses[1] + [depth_to_max_tensors_scalarlist_of_complex_double[0]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; + int block_to_chunk[depth_to_max_blocks[1 - 1]]; +}; + +template <> +struct TensorListScalarListMetadata, 2> { + const void* addresses[2] + [depth_to_max_tensors_scalarlist_of_complex_double[1]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; + int block_to_chunk[depth_to_max_blocks[2 - 1]]; +}; + +// NOTE(crcrpar): This is a conservative resolution to handle `state_steps` +// whose each element is `at::Tensor` of 1 element representing the number of +// `step`s called so far. +template +struct FusedOptimizerTensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +C10_LAUNCH_BOUNDS_1(kBlockSize) +__global__ void multi_tensor_apply_kernel( + T tensorListMeta, + U callable, + ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however + // it likes. + callable(kChunkSize, tensorListMeta, args...); +} + +} // namespace + +// multi_tensor_apply enables horizontal fusion across lists of tensors. +// For example, whereas you once had a for-loop of a + b = c, where a, b, +// and c are individual tensors in lists as, bs, and cs, you can now with +// fewer kernel launches compute as + bs = cs. +// +// You can also imagine bs to be a scalar list vs a tensor list. +// +// The function below takes in tensor lists, scalars, and a callable and +// chunks up the computation to launch as few kernels as possible by iterating +// through every "chunk" in every tensor (thus the nested for loops). In the +// simplest case, everything gets bundled into just one kernel launch, but +// due to blocksize constraints, we may need to launch multiple kernels. +// Each kernel launch is defined by one tensorListMeta construct, which we +// use to track and reset the necessary metadata for each launch. +template +void multi_tensor_apply( + std::vector>& tensor_lists, + at::ArrayRef scalars, + T callable, + ArgTypes... 
args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + using scalar_vals_t = typename T::opmath_t; + TensorListScalarListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // now we enter [chunking territory]. + // we will launch a kernel when EITHER the blocks get filled up OR + // the tensors get filled up. There will always be at least one block + // per tensor since the zero-sized ones will not enter the loop, so + // the nested forloop within represents iterating through the chunks + // of a single tensor. + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + // a tensor is not considered full unless all its chunks have been + // processed + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + // all chunks have already been handled in the kernel + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { // blocks were full and tensor chunks remain + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.scalar_vals[0] = + tensorListMeta.scalar_vals[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // note: [finishing what we started] + // if there's remaining work to be done but the tensors/blocks aren't full + // yet we are at the end, submit the kernel to do the work! + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply( + std::vector>& tensor_lists, + T callable, + ArgTypes... 
args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + TensorListMetadata tensorListMeta; + tensorListMeta.start_tensor_this_launch = 0; + + int loc_block_info = 0; + int loc_tensor_info = 0; + int processed = 0; + + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + processed++; + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // see note: [chunking territory]. + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + tensorListMeta.start_tensor_this_launch = processed; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + tensorListMeta.start_tensor_this_launch = processed - 1; + } + } + } + } + + // see note: [finishing what we started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply_for_fused_optimizer( + std::vector>& tensor_lists, + at::TensorList state_steps, + T callable, + ArgTypes... 
args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth"); + const auto num_tensors = tensor_lists[0].size(); + FusedOptimizerTensorListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (const auto& tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + tensorListMeta.state_steps_addresses[loc_tensor_info] = + state_steps[tensor_index].const_data_ptr(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][tensor_index].numel(); + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][tensor_index].const_data_ptr(); + } + loc_tensor_info++; + + // see above note: [chunking territory] + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + for (const auto& chunk : c10::irange(chunks)) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const auto tensor_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1]; + + if (tensor_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.state_steps_addresses[0] = + tensorListMeta.state_steps_addresses[loc_tensor_info - 1]; + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // see above note: [finishing what we've started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Normalization.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Normalization.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c7555f114a8e7460cda4c80b36bfdedb2afee5f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Normalization.cuh @@ -0,0 +1,1742 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +namespace at::native { + +// The maximum number of threads in a block +#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; +#else +constexpr int MAX_BLOCK_SIZE = 512; +#endif + +constexpr unsigned MAX_GRID_SIZE = 65535u; + +// Number of threads in a block given an input size up to MAX_BLOCK_SIZE +static int getNumThreads(int nElem) { +#if defined(USE_ROCM) + int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; +#else + int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; +#endif + for (int i = 0; i != 5; ++i) { + if (nElem <= 
threadSizes[i]) { + return threadSizes[i]; + } + } + return MAX_BLOCK_SIZE; +} + +// Returns the index of the most significant 1 bit in `val`. +__device__ __forceinline__ int getMSB(int val) { + return 31 - __clz(val); +} + +template +struct Float2 { + accscalar_t v1, v2; + __device__ Float2() {} + __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast(v1)), v2(static_cast(v2)) {} + __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} + __device__ Float2& operator+=(const Float2& a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } + __device__ friend Float2 operator+(Float2 a, const Float2& b) { + a += b; + return a; + } +}; + +template +struct GradOp { + __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g) + : mean(m), input(i), grad_output(g) {} + __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { + accscalar_t g = grad_output[batch][plane][n]; + accscalar_t c = static_cast(input[batch][plane][n]) - mean; + return Float2(g, g * c); + } + const accscalar_t mean; + const PTA& input; + const PTA& grad_output; +}; + +template +struct SumReduceOp { + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +}; + +template +struct SumReduceOp> { + using acc_t = Float2; + + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)}; + } +}; + +// Sum across (batch, x/y/z) applying Op() pointwise +// this works by first having each thread sum it's part +// of the data. Then there is a double-shuffling reduction. +// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its +// data to the "warp leader", who writes its value into shared memory. +// Then a single warp reads the remaining (at most C10_WARP_SIZE) items +// and reduces them using another warpSum. +// The implicit assumption is that there are no more +// than C10_WARP_SIZE**2 threads. 
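+// A minimal sketch of that two-stage pattern for a plain float sum, assuming the
+// block size is a multiple of the warp size (and, as above, at most
+// C10_WARP_SIZE**2 threads); the templated `reduce` below is the generalized
+// version actually used, working on any accumulator type and reduce op.
+__device__ __forceinline__ float block_sum_sketch(float val) {
+  // stage 1: every warp folds its own values into its lane 0
+  for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
+    val += WARP_SHFL_DOWN(val, offset);
+  }
+  __shared__ float warp_sums[C10_WARP_SIZE];   // one partial sum per warp
+  const int tid  = threadIdx.x + threadIdx.y * blockDim.x;
+  const int lane = tid % C10_WARP_SIZE;
+  const int warp = tid / C10_WARP_SIZE;
+  if (lane == 0) {
+    warp_sums[warp] = val;                     // warp leader publishes its partial sum
+  }
+  __syncthreads();
+  // stage 2: the first warp folds the (at most C10_WARP_SIZE) partial sums
+  const int num_warps = (blockDim.x * blockDim.y + C10_WARP_SIZE - 1) / C10_WARP_SIZE;
+  val = (tid < num_warps) ? warp_sums[tid] : 0.f;
+  if (warp == 0) {
+    for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
+      val += WARP_SHFL_DOWN(val, offset);
+    }
+  }
+  return val;  // the block-wide sum is valid in thread 0 only
+}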
+template +__device__ scalar_t reduce(Op op, PTA tensor, int plane) { + // first the reductions each thread does separately + scalar_t sum = static_cast(0); + for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { + sum += op(batch, plane, x); + } + } + __shared__ scalar_t shared[C10_WARP_SIZE]; + SumReduceOp reduce_op; + sum = cuda_utils::BlockReduce, cuda_utils::Block2D>(sum, reduce_op, 0, shared); + if (threadIdx.x == 0 && threadIdx.y == 0) { + shared[0] = sum; + } + __syncthreads(); + // Everyone picks it up, should be broadcast into the whole grad_input + return shared[0]; +} + +constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency +constexpr int ELEMENTS_PER_THREAD = 16; +constexpr int OPTIMAL_TILE_W = 32; +constexpr int MAX_H_BLOCK = 128; + +__host__ void flexible_launch_configs( + const int reduction, + const int stride, + dim3 &block, + dim3 &grid, + const bool coop_flag = false) { + int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W); + int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)), + MAX_BLOCK_SIZE / block_x); + if (block_x * block_y != MAX_BLOCK_SIZE) { + block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y); + } + + int grid_x = at::ceil_div(stride, block_x); + int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK); + if (coop_flag) { + // it's not worth having a grid reduction if the reduction dimension is not big enough + grid_y = grid_y < 8 ? 1 : grid_y; + } + + block.x = block_x; + block.y = block_y; + block.z = 1; + grid.x = grid_x; + grid.y = grid_y; + grid.z = 1; +} + +template +__device__ __forceinline__ void welford_merge_element(C& count, + T& mean, + T& m2n, + const C& count_new, + const T& mean_new, + const T& m2n_new) { + T factor = T(1.0) / ::max(1, (count + count_new)); + T delta0 = mean - mean_new; + mean = (mean_new * count_new + mean * count) * factor; + m2n += m2n_new + delta0 * delta0 * count_new * count * factor; + count += count_new; +} + +// merge mean/m2n among threadIdx.y within block +template +__device__ __forceinline__ void welford_merge_block_vertical(C& count, + T& mean, + T& m2n, + C* shmem_count, + T* shmem_mean, + T* shmem_m2n) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_mean[address_base] = mean; + shmem_m2n[address_base] = m2n; + shmem_count[address_base] = count; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + // read shared memory back to register for reduction + auto count_new = shmem_count[address]; + auto mean_new = shmem_mean[address]; + auto m2n_new = shmem_m2n[address]; + + welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new); + } + } +} + +template +__global__ void batch_norm_transform_input_kernel( + const GenericPackedTensorAccessor input, + GenericPackedTensorAccessor output, + const GenericPackedTensorAccessor, 1, RestrictPtrTraits, index_t> mean_, + const GenericPackedTensorAccessor, 1, RestrictPtrTraits, index_t> var_or_invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor bias, + stat_accscalar_t epsilon) { + + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + stat_accscalar_t 
gamma = weight.size(0) > 0 ? static_cast(weight[plane]) : static_cast(1); + stat_accscalar_t beta = bias.size(0) > 0 ? static_cast(bias[plane]) : static_cast(0); + stat_accscalar_t mean = static_cast(mean_[plane]); + stat_accscalar_t invstd; + if (train) { + invstd = var_or_invstd[plane]; + } else { + invstd = static_cast(1) / device_sqrt(static_cast(var_or_invstd[plane]) + epsilon); + } + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto o = output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + o[feature] = static_cast(gamma * (i[feature] - mean) * invstd + beta); + } + } +} + +struct InvStd { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + T invstd = 0; + if (var != static_cast(0) || epsilon != static_cast(0)) { + invstd = static_cast(1) / device_sqrt(var + epsilon); + } + return invstd; + } +}; + +struct Var { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + return var; + } +}; + +template +__global__ void batch_norm_collect_statistics_kernel( + const GenericPackedTensorAccessor input, + const stat_accscalar_t epsilon, + const stat_accscalar_t momentum, + GenericPackedTensorAccessor save_mean, + GenericPackedTensorAccessor save_transformed_var) { + + __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; + + int plane = blockIdx.x; + int N = input.size(0) * input.size(2); + int tid = threadIdx.x + threadIdx.y * blockDim.x; + + // Compute the mean and variance across (batch, x/y/z) + // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block) + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + // and the parallel algorithm on the same page. + // We use two shuffles to reduce across the entire block. + // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. 
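+  // Spelled out, the per-thread (Welford) loop below maintains (n, avg, var_n) via
+  //   n     += 1
+  //   d1     = v - avg
+  //   avg   += d1 / n
+  //   var_n += d1 * (v - avg)     // var_n is the running sum of squared deviations
+  // and two partial results (n, avg, var_n) and (o_n, o_avg, o_var_n) are merged via
+  //   factor = 1 / max(1, n + o_n)
+  //   var_n += o_var_n + (avg - o_avg)^2 * n * o_n * factor
+  //   avg    = (n * avg + o_n * o_avg) * factor
+  //   n     += o_n
+  // which is exactly what the two warp-shuffle passes below compute.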
+ stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE]; + + // first the reductions each thread does separately + stat_accscalar_t avg = 0; + stat_accscalar_t var_n = 0; + int n = 0; + for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { + stat_accscalar_t v = input[batch][plane][x]; + stat_accscalar_t d1 = v - avg; + n++; + avg += d1 / n; + var_n += d1 * (v - avg); + } + } + + // first warpSum to get one value per thread to + // one value per warp + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // this writes each warps item into shared memory + // there are at most C10_WARP_SIZE items left because + // there are at most C10_WARP_SIZE**2 threads at the beginning + __syncthreads(); + if (tid % C10_WARP_SIZE == 0) { + shared_n[tid / C10_WARP_SIZE] = n; + shared_avg_var[tid / C10_WARP_SIZE * 2] = avg; + shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n; + } + __syncthreads(); + // now have a second warpSum to reduce the intermediate values + // from shared memory to a single number. The very first + // thread writes it to shared memory. + + if (tid < C10_WARP_SIZE) { + n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0); + avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0)); + var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0)); + } + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // Save the mean, variance, and moving averages + if (tid == 0) { + if (save_mean.data() != NULL) { + save_mean[plane] = avg; + } + if (save_transformed_var.data() != NULL) { + save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon); + } + } + +} + +template +__global__ void batch_norm_backward_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor grad_input, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor running_mean, + const GenericPackedTensorAccessor running_var, + const GenericPackedTensorAccessor save_mean, + const GenericPackedTensorAccessor save_invstd, + bool train, + stat_accscalar_t epsilon) { + + index_t plane = blockIdx.x; + index_t N = grad_output.size(0) * grad_output.size(2); + + stat_accscalar_t mean, invstd; + if (train) { + mean = save_mean[plane]; + invstd = save_invstd[plane]; + } else { + mean = static_cast(running_mean[plane]); + invstd = static_cast(1) / device_sqrt(static_cast(running_var[plane]) + epsilon); + } + + stat_accscalar_t weight_val = weight.size(0) > 0 ? 
static_cast(weight[plane]) : stat_accscalar_t(1); + stat_accscalar_t norm = stat_accscalar_t(1) / N; + + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(grad_output) + // 2. DotProduct(input - mean, grad_output) + GradOp> g(mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + stat_accscalar_t grad_output_sum = res.v1; + stat_accscalar_t dot_p = res.v2; + + stat_accscalar_t grad_mean = grad_output_sum * norm; + stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd; + stat_accscalar_t grad_scale = invstd * weight_val; + + if (grad_input.data() != NULL) { + for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) { + input_scalar_t go = grad_output[batch][plane][x]; + if (train) { + stat_accscalar_t inp = input[batch][plane][x]; + stat_accscalar_t proj = (inp - mean) * proj_scale; + grad_input[batch][plane][x] = static_cast((go - proj - grad_mean) * grad_scale); + } else { + grad_input[batch][plane][x] = static_cast(go * grad_scale); + } + } + } + } + + if (grad_weight.size(0) > 0) { + if (threadIdx.x == 0) { + grad_weight[plane] = static_cast(dot_p * invstd); + } + } + + if (grad_bias.size(0) > 0) { + if (threadIdx.x == 0) { + grad_bias[plane] = static_cast(grad_output_sum); + } + } +} + +template +__global__ void batch_norm_reduce_statistics_kernel( + const GenericPackedTensorAccessor vec_mean, + const GenericPackedTensorAccessor vec_invstd, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, + const accscalar_t epsilon, + const accscalar_t momentum, + const GenericPackedTensorAccessor counts) { + + int feature_size = vec_mean.size(1); + int world_size = vec_mean.size(0); + + int bid = blockIdx.x; + int tid = threadIdx.x; + + // first the reductions each thread does separately + for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) { + accscalar_t avg = 0; + accscalar_t var_n = 0; + index_t n = 0; + for (int j = 0; j < world_size; j++) { + scalar_t count = counts[j]; + accscalar_t m = vec_mean[j][i]; + accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]); + v = (v * v - epsilon) * count; + accscalar_t factor = 1.0 / (n + count); + var_n += v + (avg - m) * (avg - m) * n * count * factor; + avg = n * factor * avg + count * factor * m; + n += count; + } + mean[i] = avg; + invstd[i] = static_cast(1) / device_sqrt(var_n / n + epsilon); + if (running_mean.data() != NULL) { + running_mean[i] = static_cast((1 - momentum) * running_mean[i] + momentum * avg); + } + accscalar_t unbiasedVar = var_n / (n - 1); + if (running_var.data() != NULL) { + running_var[i] = static_cast((1 - momentum) * running_var[i] + momentum * unbiasedVar); + } + } + +} + +template +__global__ void batch_norm_backward_reduce_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor sum_dy, + GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias) { + + index_t plane = blockIdx.x; + + stat_accscalar_t r_mean = mean[plane]; + stat_accscalar_t factor = invstd[plane]; + + GradOp> g(r_mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + if (threadIdx.x == 0) { + if (grad_weight.size(0) > 0) { + grad_weight[plane] = 
static_cast(res.v2 * factor); + } + if (grad_bias.size(0) > 0) { + grad_bias[plane] = static_cast(res.v1); + } + if (sum_dy.size(0) > 0) { + sum_dy[plane] = static_cast(res.v1); + } + if (sum_dy_xmu.size(0) > 0) { + sum_dy_xmu[plane] = static_cast(res.v2); + } + } +} + +template +__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + stat_accscalar_t m_c = mean[plane]; + stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct; + stat_accscalar_t factor_1_c = invstd[plane]; + stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1); + factor_2_c *= factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct; + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto g_i = grad_input[batch][plane]; + auto g_o = grad_output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + g_i[feature] = static_cast((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c); + } + } +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const int* __restrict__ numel, const int world_size) { + int64_t total_numel = 0; + for (int i = 0; i < world_size; i ++) { + total_numel += numel[i]; + } + + const stat_accscalar_t norm_fct = + static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +static GenericPackedTensorAccessor get_packed_accessor( + const Tensor& t, std::string_view var_name) { + constexpr auto expect_type = c10::CppTypeToScalarType>::value; + const auto actual_type = t.scalar_type(); + TORCH_CHECK(actual_type == expect_type, "Expected ", var_name, + " to have type ", expect_type, " but got ", actual_type); + return t.generic_packed_accessor(); +} + +template class PtrTraits = DefaultPtrTraits, typename 
index_t = int64_t> +static GenericPackedTensorAccessor packed_accessor_or_dummy( + const Tensor& t, std::string_view var_name) { + if (!t.defined()) { + const std::array zeros{{0}}; + return GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); + } + return get_packed_accessor(t, var_name); +} + +template +std::tuple batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, + const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, + bool train, double epsilon, std::array grad_input_mask) { + + using accscalar_t = at::acc_type; + Tensor grad_input_; + Tensor grad_input_reshaped; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (grad_input_mask[0]) { + grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); + } + if (grad_input_mask[1]) { + grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[2]) { + grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + auto input = get_packed_accessor< + const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_output = get_packed_accessor< + const input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto grad_input = packed_accessor_or_dummy< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto grad_weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); + auto grad_bias = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto running_mean = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var"); + auto save_mean = packed_accessor_or_dummy< + const accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean"); + auto save_invstd = packed_accessor_or_dummy< + const accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd"); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + + batch_norm_backward_kernel <<>> + (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, + save_mean, save_invstd, train, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(grad_input_, grad_weight_, grad_bias_); +} + +template +void batch_norm_stats_cuda_template( + const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) { + + using accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor dummy_mean_; + Tensor dummy_var_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + + resize_output(out_mean, {n_input}); + resize_output(out_invstd, {n_input}); + auto input = get_packed_accessor< + const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, 
"input"); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + auto mean = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean"); + auto invstd = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + batch_norm_collect_statistics_kernel <<>> + (input, epsilon, 0.0, mean, invstd); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_, + const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1}); + + auto input = get_packed_accessor< + const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); + auto output = get_packed_accessor< + input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight"); + auto bias = packed_accessor_or_dummy< + const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + // NOTE: We use transform_input_kernel in training mode, which ignores epsilon + const double dummy_epsilon = 1e-5; + + // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + batch_norm_transform_input_kernel <<>> + (input, output, mean, invstd, weight, bias, dummy_epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +std::tuple batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_, + const Tensor& running_mean_, const Tensor& running_var_, + double momentum, double epsilon, const Tensor& counts_) { + + Tensor save_mean_; + Tensor save_invstd_; + + auto features = mean_.size(1); + auto input_options = mean_.options(); + if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) { + input_options = input_options.dtype(ScalarType::Float); + } + save_mean_ = at::empty({features}, input_options); + save_invstd_ = at::empty({features}, input_options); + + auto mean = packed_accessor_or_dummy< + accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd"); + auto running_mean = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_mean"); + auto counts = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts"); + + auto save_mean = get_packed_accessor< + accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean"); + auto save_invstd = get_packed_accessor< + accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + int block = getNumThreads(features); + int grid = std::max(1, features/block); + batch_norm_reduce_statistics_kernel <<>> + (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(save_mean_, save_invstd_); +} + +template +std::tuple batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_, + const bool input_g, const bool weight_g, const bool bias_g) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor sum_dy_; + Tensor sum_dy_xmu_; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (input_g) { + sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (weight_g) { + grad_weight_ = at::empty({n_input}, weight_.options()); + } + if (bias_g) { + grad_bias_ = at::empty({n_input}, weight_.options()); + } + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto grad_weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); + auto grad_bias = 
packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto batch_size = input_reshaped.size(0); + auto feature_size = input_reshaped.size(2); + auto stream = at::cuda::getCurrentCUDAStream(); + + int warp_size = at::cuda::warp_size(); + int block_y = std::min(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size); + // We want block_x to be at least a warp width + int block_x = std::min(std::max(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y); + const dim3 block(block_x, block_y); + const dim3 grid(n_input); + + batch_norm_backward_reduce_kernel <<>> + (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
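+ // Hedged worked example of the launch configuration computed below (hypothetical
+ // shapes; assumes getNumThreads() rounds its argument up to the next supported
+ // thread count, e.g. 32/64/128/256/512, capped at MAX_BLOCK_SIZE). For a reshaped
+ // input of shape (N=32, C=64, HW=4096):
+ //   tf = max(getNumThreads(4096/4), min(getNumThreads(4096), 64)) = max(512, 64) = 512
+ //   tb = max(64/512, 1) = 1                     // integer division, so at least one row
+ //   blocks_trans = (C, min((256*1024)/C, ceil(N/tb))) = (64, min(4096, 32)) = (64, 32)
+ //   threads_trans = (tf, tb) = (512, 1)
+ // i.e. one column of blocks per feature plane (blockIdx.x), with the batch rows
+ // spread over blockIdx.y up to MAX_GRID_SIZE.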
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + auto reduction_size = input_.numel() / n_input; + auto norm_fct = static_cast(1.0 / reduction_size); + batch_norm_backward_elemt_kernel + <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
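+ // Note on this overload (inferred from the call below and from the channels-last
+ // kernel further down): instead of a host-side norm_fct = 1/reduction_size, the
+ // kernel is handed count.const_data_ptr() and count.numel(), i.e. one element
+ // count per participating replica, and derives the normalization factor from
+ // their sum. This is the form used when statistics and gradients are shared
+ // across processes (e.g. a synchronized batch norm); the overload above handles
+ // the single-tensor case.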
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + batch_norm_backward_elemt_kernel <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr(), count.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance +// original apex name: welford_kernel_c_last +template + +__global__ void +batch_norm_collect_statistics_channels_last_kernel( + const scalar_t* __restrict__ input, + accscalar_t* __restrict__ out_mean, + accscalar_t* __restrict__ out_invstd, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride, + accscalar_t epsilon) { + // hide latency with concurrency + accscalar_t x_mean[PARALLEL_LOADS]; + accscalar_t m_2_n[PARALLEL_LOADS]; + int count[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + x_mean[i] = accscalar_t(0); + m_2_n[i] = accscalar_t(0); + count[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_math[PARALLEL_LOADS]; + accscalar_t x_count_inv[PARALLEL_LOADS]; + accscalar_t is_valid[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_math[j] = input[address_base]; + count[j]++; + x_count_inv[j] = accscalar_t(1) / count[j]; + is_valid[j] = accscalar_t(1); + } else { + x_math[j] = accscalar_t(0); + x_count_inv[j] = accscalar_t(0); + is_valid[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate mean/m2n with welford +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + accscalar_t delta0 = x_math[j] - x_mean[j]; + x_mean[j] += delta0 * x_count_inv[j]; + accscalar_t delta1 = x_math[j] - x_mean[j]; + m_2_n[j] += delta0 * delta1 * is_valid[j]; + } + } + + // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); + } + + // release x_mean / m_2_n + auto mean_th = x_mean[0]; + auto m2_th = m_2_n[0]; + auto count_th = count[0]; + + // block-wise reduction with shared memory (since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE]; + static __shared__ int shmem_count[MAX_BLOCK_SIZE]; + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_mean = staging_data; + volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y]; + volatile 
int* staging_count = reinterpret_cast(&staging_m2n[stride*gridDim.y]); + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_mean[address_base] = mean_th; + staging_m2n[address_base] = m2_th; + staging_count[address_base] = count_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + count_th = 0; + mean_th = accscalar_t(0.0); + m2_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + int count_new = c_offset < stride ? staging_count[address_base] : 0; + accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0); + accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0); + + welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new); + } + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); + if (threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_forward_c_last_kernel +template < + typename scalar_t, + typename accscalar_t, + typename layerscalar_t, + int PARALLEL_LOADS> +__global__ void batch_norm_transform_input_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ z, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const layerscalar_t* __restrict__ shift, + scalar_t* __restrict__ out, + const int reduction_size, + const int stride, + const bool fuse_relu) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto inv_std_c = static_cast(inv_std[c_offset]); + auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast(weight[c_offset]); + auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast(shift[c_offset]); + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + auto tmp = w_c * (static_cast(input[address_base]) - m_c ) * inv_std_c + s_c; + if (z != nullptr) { + tmp += z[address_base]; + } + out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? 
scalar_t(0.0) : static_cast(tmp)); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template +__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, + T& sum_dy_xmu, + T* shmem_sum_dy, + T* shmem_sum_dy_xmu) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_sum_dy[address_base] = sum_dy; + shmem_sum_dy_xmu[address_base] = sum_dy_xmu; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + + sum_dy += shmem_sum_dy[address]; + sum_dy_xmu += shmem_sum_dy_xmu[address]; + } + } +} + +// batchnorm backward kernel for c last tensor +// original apex name: reduce_bn_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_reduce_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ grad_output, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + accscalar_t* __restrict__ sum_dy_o, + accscalar_t* __restrict__ sum_dy_xmu_o, + layerscalar_t* __restrict__ grad_weight, + layerscalar_t* __restrict__ grad_bias, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride) { + + // hide latency with concurrency + accscalar_t sum_dy[PARALLEL_LOADS]; + accscalar_t sum_dy_xmu[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + sum_dy[i] = accscalar_t(0); + sum_dy_xmu[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + auto r_mean = mean[c_offset]; + auto factor = inv_std[c_offset]; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_input[PARALLEL_LOADS]; + accscalar_t x_grad_output[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_input[j] = input[address_base]; + x_grad_output[j] = grad_output[address_base]; + } else { + x_input[j] = accscalar_t(0); + x_grad_output[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate sum_dy / sum_dy_xmu +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + sum_dy[j] += x_grad_output[j]; + sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); + } + } + + // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + sum_dy[0] += sum_dy[j]; + sum_dy_xmu[0] += sum_dy_xmu[j]; + } + + // release array of registers + auto sum_dy_th = sum_dy[0]; + auto sum_dy_xmu_th = sum_dy_xmu[0]; + + // block-wise reduction with shared memory (since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t 
shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_sum_dy = staging_data; + volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y]; + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_sum_dy[address_base] = sum_dy_th; + staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + sum_dy_th = accscalar_t(0.0); + sum_dy_xmu_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0)); + sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0)); + } + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + if (threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_backward_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto m_dy_c = sum_dy[c_offset] * norm_fct; + auto factor_1_c = inv_std[c_offset]; + auto factor_2_c = (weight == nullptr? 
accscalar_t(1.0) : static_cast(weight[c_offset])) * factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + grad_input[address_base] = static_cast( + (static_cast(grad_output[address_base]) - m_dy_c - + (static_cast(input[address_base]) - m_c) * factor_1_c) + * factor_2_c); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + const int* __restrict__ numel, + scalar_t* __restrict__ grad_input, + const int64_t world_size, + const int reduction_size, + const int stride) { + + int64_t total_numel = 0; + for (int i = 0; i < world_size; i++) { + total_numel += numel[i]; + } + + auto norm_fct = static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template +void batch_norm_stats_channels_last_cuda_template( + const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) { + using accscalar_t = at::acc_type; + + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + resize_output(out_mean, {stride}); + resize_output(out_invstd, {stride}); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = at::empty({4*stride*grid.y}, out_mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? 
semaphores.mutable_data_ptr() : nullptr; + batch_norm_collect_statistics_channels_last_kernel + <<>>( + input.const_data_ptr(), + out_mean.mutable_data_ptr(), + out_invstd.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride, + epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void batch_norm_elemt_channels_last_cuda_template( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& shift, // bias of BN + const at::Tensor& mean, + const at::Tensor& inv_std, + const std::optional& z = std::nullopt, // bias after BN + const bool fuse_relu = false) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + const auto second_dtype = weight.defined() ? weight.scalar_type() : + (shift.defined() ? shift.scalar_type() : input.scalar_type()); + + if (input.scalar_type() != second_dtype) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? shift.const_data_ptr() : nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()){ + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? 
shift.const_data_ptr(): nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +std::tuple +batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const bool input_g, const bool weight_g, const bool bias_g) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::Tensor sumn_dy = at::empty({stride}, mean.options()); + at::Tensor sum_dy_xmu = at::empty({stride}, mean.options()); + + at::Tensor grad_weight; + at::Tensor grad_bias; + if (weight.defined()) { + grad_weight = at::empty({stride}, weight.options()); + grad_bias = at::empty({stride}, weight.options()); + } else { + // because I cannot return an uninitialized at::Tensor + grad_weight = at::empty({0}, mean.options()); + grad_bias = at::empty({0}, mean.options()); + } + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = at::empty({2*stride*grid.y}, mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + auto stream = at::cuda::getCurrentCUDAStream(); + + if (weight.defined() && input.scalar_type() != weight.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + grad_weight.mutable_data_ptr(), + grad_bias.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()) { + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + weight.defined() ? grad_weight.mutable_data_ptr() : nullptr, + weight.defined() ? 
grad_bias.mutable_data_ptr() : nullptr, + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + + return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias); +} + +at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu, + const at::Tensor& count) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + count.const_data_ptr(), + grad_input.mutable_data_ptr(), + count.numel(), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()) { + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? 
weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + count.const_data_ptr(), + grad_input.mutable_data_ptr(), + count.numel(), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + + return grad_input; +} + +at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + auto norm_fct = 1.0 / reduction_size; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + + return grad_input; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh new file mode 100644 index 0000000000000000000000000000000000000000..69895a3f22f69c8e9867e18ad397137c49ae6f8d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh @@ -0,0 +1,402 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace { + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension. +// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024. +// The template arguments have the following meaning: +// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples. 
+// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small. +// A "WARP" contains "C10_WARPS_SIZE" threads, these treads are guaranteed to belong to the same warp. +// This is important because it means only __shfl_ instructions are required for reductions. +// Note that this means WARP_SIZE must be a power of two and <= architecture warp size. +// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future arch. +// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs. +// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed. +// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed. +// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t. +// This allows SoftMax to be fused with a cast immediately following the SoftMax. +// The mask should have the same shape as input, with a boolean indicate if the value is masked. +// The head_chunk_size is only used for transformer mask softmax, equals to H * D * D. +// For instance: +// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor. +// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor. +// input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor. + +template +__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + int idx_offset = first_batch * stride + local_idx; + + src += idx_offset; + dst += idx_offset; + + if (is_transformer_mask) { + mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx; + } else { + mask += idx_offset; + } + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. + // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + elements[i][it] = src[i*element_count+it*WARP_SIZE]; + } else { + elements[i][it] = -std::numeric_limits::infinity(); + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + bool is_meaningful_max = false; + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (is_masked) { + int idx = it*WARP_SIZE; + if ((idx + local_idx) < batch_element_count) { + if (!is_transformer_mask) { + idx += i*element_count; + } + if (!mask[idx]) { + max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + is_meaningful_max = true; + } + } + } else { + max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it]; + } + } + if (is_masked) { + if (!is_meaningful_max) { + max_value[i] = -std::numeric_limits::infinity(); + } + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (!is_masked) { + if (is_log_softmax) { + sum[i] += std::exp(elements[i][it] - max_value[i]); + } else { + elements[i][it] = std::exp(elements[i][it] - max_value[i]); + sum[i] += elements[i][it]; + } + } else { + int idx = it*WARP_SIZE; + bool valid = (idx + local_idx) < batch_element_count; + if (!is_transformer_mask) { + idx += i*element_count; + } + if (valid) { + if (!mask[idx]) { + if (is_log_softmax) { + sum[i] += std::exp(elements[i][it] - max_value[i]); + } else { + elements[i][it] = std::exp(elements[i][it] - max_value[i]); + sum[i] += elements[i][it]; + } + } else { + if (!is_log_softmax) { + // Masked values are treated as -infinity, and std::exp(-infinity) is 0. + elements[i][it] = 0; + } + } + } else { + if (!is_log_softmax) { + elements[i][it] = 0.; + } + } + } + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + if (is_log_softmax) sum[i] = std::log(sum[i]); + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_log_softmax) { + dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i]; + } else if (sum[i] == 0) { + dst[i*element_count+it*WARP_SIZE] = std::numeric_limits::quiet_NaN(); + } else { + dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i]; + } + } else { + break; + } + } + } +} + +template +__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x % WARP_SIZE; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + if (is_masked) { + mask += thread_offset; + } + + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. + // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE]; + output_reg[i][it] = output[i*element_count+it*WARP_SIZE]; + } else { + grad_reg[i][it] = acc_t(0); + output_reg[i][it] = acc_t(0); + } + } + } + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (!is_masked || !mask[i*element_count+it*WARP_SIZE]) { + sum[i] += grad_reg[i][it]; + } + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_masked && mask[i*element_count+it*WARP_SIZE]) { + gradInput[i*element_count+it*WARP_SIZE] = 0; + } + // compute gradients + else if (is_log_softmax) { + gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); + } else { + gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]); + } + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = at::cuda::warp_size(); + warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ + softmax_warp_forward \ + <<>>(dst, \ + src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024 + LAUNCH_SOFTMAX_WARP_FORWARD(11); // 2048 + default: + break; + } + } +} + +template +void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = at::cuda::warp_size(); + warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ + softmax_warp_backward \ + <<>> \ + (grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements, mask); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024 + default: + break; + } + } +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Pow.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Pow.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8573a373a3636fff430f066b00f593216b06e0eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Pow.cuh @@ -0,0 +1,58 @@ +#pragma once +#include +#include + +namespace at::native { + +namespace { + + +// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt. +// So we need to define the functions with the explicit function signatures. +// As for pow, the following signatures are defined as the device function: +// pow(float, int) +// pow(double, int) +// pow(float, float) +// pow(double, double) +#ifdef _MSC_VER +// Functions for pow +// pow for at::Half +static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow for at::BFloat16 +static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow (floating, floating/int) +template +static inline __host__ __device__ typename std::enable_if_t && (std::is_same_v || std::is_same_v), Base_type> + pow_(Base_type base, Exp_type exp) { + return std::pow(base, exp); +} +// pow (Otherwise) +template +static inline __host__ __device__ typename std::enable_if_t && !std::is_same_v, Base_type> + pow_(Base_type base, Exp_type exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +#else +template +static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { + return ::pow(base, exp); +} +#endif + +template +static inline __host__ __device__ std::enable_if_t, T> pow_( + T base, T exp) { + return at::native::powi(base, exp); +} + +template +static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { + return c10_complex_math::pow(base, exp); +} + +} // namespace +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Randperm.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Randperm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cd0f61a73277c2cec7959c04151c3f630699aebc --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Randperm.cuh @@ -0,0 +1,58 @@ +#include +#include +#include + +#include +#include +#include + +namespace { + +// See note [Algorithm of randperm] +template +__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + + // find the beginning of islands + if (tid >= n - 1) return; // out of range + if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return; // not in an island + if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return; // not the beginning of an island + + // find the size of islands + int island_size = 0; + do { island_size++; } + while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask)); + + // do random permutation inside each island. + data += tid; + const auto [seed, offset] = at::cuda::philox::unpack(philox_args); + curandStatePhilox4_32_10_t state; + curand_init(seed, tid, offset, &state); + for (int i = island_size - 1; i > 0; i--) { + unsigned int r = curand(&state) % (i + 1); + if (i != r) { + scalar_t tmp = data[i]; + data[i] = data[r]; + data[r] = tmp; + } + } +} + +// See note [Algorithm of randperm] +template +void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional &gen_) { + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); + int64_t counter_offset = n; + at::PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + T mask = static_cast((1UL << bits) - 1); + randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( + keys, data, mask, n, rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Reduce.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ceafc409611400afb89991cab3dd85e7a77819f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Reduce.cuh @@ -0,0 +1,1402 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +static inline int64_t div_up(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +// returns floor(log2(n)) +static inline int last_pow2(int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +// returns reduced fraction numerator & denominator +C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. 
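+ // Worked example (illustrative): reduce_fraction(numerator=6, denominator=8) runs
+ // Euclid's algorithm on (a=8, b=6): 8 % 6 = 2, 6 % 2 = 0, so the GCD is 2 and the
+ // pair is reduced in place to 3/4. ReduceOp::run() below relies on this to bring
+ // the sizeof(arg_t)/sizeof(out_scalar_t) ratio to lowest terms before indexing
+ // into acc_buf.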
+ size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; +} + +//template for changing MAX_NUM_THREADS based on op dtype +template +struct mnt_wrapper { + static constexpr int MAX_NUM_THREADS = 512; +}; + +template <> +struct mnt_wrapper >{ + static constexpr int MAX_NUM_THREADS = 256; +}; + +constexpr int max_reduce_threads(c10::ScalarType type) { + return type == kComplexDouble ? 256 : 512; +} + +struct ReduceConfig { + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs) + : element_size_bytes(element_size_bytes) + , num_inputs(num_inputs) + , num_outputs(num_outputs) {} + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + template + void set_block_dimension(int64_t dim0, int64_t dim1) { + const int max_num_threads = mnt_wrapper::MAX_NUM_THREADS / output_vec_size; + int dim0_pow2 = dim0 < max_num_threads ? static_cast(last_pow2(dim0)) : max_num_threads; + int dim1_pow2 = dim1 < max_num_threads ? static_cast(last_pow2(dim1)) : max_num_threads; + block_width = std::min(dim0_pow2, int(at::cuda::warp_size())); + block_height = std::min(dim1_pow2, int(max_num_threads / block_width)); + block_width = std::min(dim0_pow2, int(max_num_threads / block_height)); + num_threads = block_width * block_height; + } + + int split_input(int parallelism) { + int step = step_input; + step_input *= parallelism; + return step; + } + + int split_output(int parallelism) { + int step = step_output; + step_output *= parallelism; + return step; + } + + dim3 block() const { + return dim3(block_width, block_height); + } + + dim3 grid() const { + return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output); + } + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx < num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + offset = 
threadIdx.x + offset * blockDim.x; + } + return offset; + } + + int shared_memory_size() const { + if (!should_block_y_reduce() && + (!should_block_x_reduce() || + block_width <= at::cuda::warp_size())) { + return 0; + } + return element_size_bytes * num_threads * output_vec_size; + } + + int64_t global_memory_size() const { + if (!should_global_reduce()) { + return 0; + } + auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output; + if (!should_block_x_reduce()) { + size *= block().x * output_vec_size; + } + return size; + } + + int semaphore_size() const { + if (!should_global_reduce()) { + return 0; + } + return sizeof(int) * grid().x; + } + + int values_per_thread() const { + return div_up(num_inputs, step_input); + } +}; + +std::ostream& operator<<(std::ostream& out, const ReduceConfig& config); + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void reduce_kernel(R reduction) { + reduction.template run(); +} + +template +static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) { + int num_reduce_dims = iter.num_reduce_dims(); + int num_output_dims = iter.ndim() - num_reduce_dims; + int input_index = iter.ntensors() - 1; + int output_index = 0; + std::array strides = { + iter.strides(output_index).data() + num_reduce_dims, + iter.strides(input_index).data() + num_reduce_dims, + }; + auto shape = iter.shape().data() + num_reduce_dims; + return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data()); +} + +template +static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) { + int num_reduce_dims = iter.num_reduce_dims(); + int input_index = iter.ntensors() - 1; + std::array strides = { + iter.strides(input_index).data(), + }; + return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data()); +} + +template +struct func_wrapper_t { + using arg_t = typename binary_function_traits::arg1_t; + using scalar_t = typename binary_function_traits::arg2_t; + + func_t combine; + static inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + static __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + func_wrapper_t(const func_t& op) : combine(op) { + } + + // wrap a normal reduction that ignores the index + __device__ arg_t reduce(arg_t acc, scalar_t val, int64_t idx) const { + return combine(acc, val); + } +}; + +template +func_wrapper_t func_wrapper(const func_t& op) { + return func_wrapper_t { op }; +} + +template +struct ReduceJitOp { +//ReduceJitOp is almost like ReduceOp, but it doesn't have ops functor that specifies reduction operations +//Maybe we can find a way to unify ReduceOp and ReduceJitOp + using InputCalculator = OffsetCalculator<1, uint32_t>; + using OutputCalculator = OffsetCalculator<2, uint32_t>; + //TODO for now arg_t is always opmath_t of the input, later we'll need to change it + using arg_t = at::opmath_type; + + //TODO - ReduceJitOp will probably need to be changed for reductions that need full functor, + //not just wrapper + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during 
global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + ReduceJitOp( + ReduceConfig config, + InputCalculator input_calc, + OutputCalculator output_calc, + const void* src, + char* dst0, + std::optional dst1, + void* acc_buf, + void* cta_buf, + int* semaphores, + arg_t ident, + int noutputs, + int64_t base_idx) + : ident(ident), + config(config), + input_calc(input_calc), + output_calc(output_calc), + src(src), + acc_buf(acc_buf), + cta_buf(cta_buf), + semaphores(semaphores), + base_idx(base_idx), + noutputs(noutputs) { + dst[0] = dst0; + if (dst1.has_value()) { + dst[1] = dst1.value(); + } + } +}; + +template +struct ReduceOp { + using traits = function_traits; + using arg_t = typename std::decay::type>::type; + + using InputCalculator = OffsetCalculator<1, index_t>; + using OutputCalculator = OffsetCalculator<2, index_t>; + + static constexpr bool can_accumulate_in_output = + std::is_convertible_v + && std::is_convertible_v; + + ops_t ops; + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + ReduceOp( + ops_t ops, + ReduceConfig config, + InputCalculator input_calc, + OutputCalculator output_calc, + const void* src, + char* dst0, + std::optional dst1, + void* acc_buf, + void* cta_buf, + int* semaphores, + arg_t ident, + int noutputs, + int64_t base_idx) + : ops(ops), + ident(ident), + config(config), + input_calc(input_calc), + output_calc(output_calc), + src(src), + acc_buf(acc_buf), + cta_buf(cta_buf), + semaphores(semaphores), + base_idx(base_idx), + noutputs(noutputs) { + dst[0] = dst0; + if (dst1.has_value()) { + dst[1] = dst1.value(); + } + } + + template + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + index_t output_idx = config.output_idx(); + index_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = std::array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + value = thread_reduce(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + + using out_ptr_vec_t = std::array; + using offset_vec_t = std::array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = global_reduce(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < 
output_vec_size; i++) { + value[i] = ops.translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE std::array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + CUDA_KERNEL_ASSERT(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. + return {input_vectorized_thread_reduce_impl(data)}; + } else { + index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](index_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](index_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + index_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(at::native::memory::aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = at::native::memory::aligned_vector; + + index_t idx = config.input_idx(); + const index_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. 
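+    // (Illustrative note, not part of the upstream header.) With input_vec_size == 4,
+    // the unrolled loop below keeps four independent partial sums, so the reduce for
+    // slot i never waits on the result of slot i-1; the partials are only folded into
+    // value_list[0] once, in the "combine accumulators" loop at the end.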
+ arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + while (idx * input_vec_size + input_vec_size - 1 < end) { + const auto values_vec = memory::load_vector(data, idx); + #pragma unroll + for (index_t i = 0; i < input_vec_size; i++) { + value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + index_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + const auto value = c10::load(data + idx); + value_list[0] = ops.reduce(value_list[0], value, idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = ops.combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE std::array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + index_t idx = config.input_idx(); + const index_t end = config.num_inputs; + const index_t stride = config.step_input; + + using arg_vec_t = std::array; + using load_t = at::native::memory::aligned_vector; + + // Multiple accumulators to remove dependency between unrolled loops. + arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + const auto offset = calc(idx + i * stride) / output_vec_size; + values[i] = memory::load_vector(data_, offset); + } + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + const auto offset = calc(idx) / output_vec_size; + values[i] = memory::load_vector(data_, offset); + idx += stride; + } + idx = idx_; + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + + template + C10_DEVICE std::array block_x_reduce(std::array value, char* shared_memory) const { + using args_vec_t = std::array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], other[i]); + } + shared[address_base] = value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + for (int offset = 1; offset < dim_x; offset <<= 1) { + #pragma unroll + for (int i = 0; 
i < output_vec_size; i++) { + arg_t other = ops.warp_shfl_down(value[i], offset); + value[i] = ops.combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE std::array block_y_reduce(std::array value, char* shared_memory) const { + using args_vec_t = std::array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE std::array accumulate_in_output( + std::array out, + std::array value, + typename std::enable_if_t* = nullptr + ) const { + std::array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + ret[i] = ops.combine(*(out[i]), value[i]); + } + return ret; + } + + template + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value, + typename std::enable_if_t* = nullptr + ) const { + CUDA_KERNEL_ASSERT(!final_output); + return (out_scalar_t)value; + } + + // This function should never be called -- + // it's the version of `accumulate_in_output` + // when accumulation in the output is not possible. + template + C10_DEVICE std::array accumulate_in_output( + std::array, + std::array, + typename std::enable_if_t* = nullptr + ) const { + CUDA_KERNEL_ASSERT(false); + return {arg_t{}}; + } + + // This function should never be called -- + // it's the version of `get_accumulated_output` + // when accumulation in the output is not possible. 
+ template + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value, + typename std::enable_if_t* = nullptr + ) const { + CUDA_KERNEL_ASSERT(false); + return *out; + } + + template + C10_DEVICE void set_results(const T x, const index_t base_offset) const { + CUDA_KERNEL_ASSERT(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + + //Currently implemented for max of two outputs + template + C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { + if (noutputs >= 1) { + auto res0 = (T1*)((char*)dst[0] + base_offset); + *res0 = x.first; + } + if (noutputs >= 2) { + // base offset is computed assuming element size being sizeof(T1), so we need to make a + // correction to obtain the correct base offset + auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); + *res1 = x.second; + } + } + + template + C10_DEVICE void set_results_to_output(std::array value, std::array base_offset) const { + CUDA_KERNEL_ASSERT(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(ops.project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE std::array global_reduce(std::array value, std::array *acc, char* shared_memory) const { + using arg_vec_t = std::array; + using out_ptr_vec_t = std::array; + using offset_vec_t = std::array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + index_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + index_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + __threadfence(); // complete the acquire pattern after atomic + for (auto &v : value) { + v = ident; + } + if (config.should_block_x_reduce()) { + index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + index_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + index_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], next[i]); + } + } + } else { + index_t input_offset = threadIdx.y; + index_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + index_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for 
(int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +template +static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) { + dim3 block = config.block(); + dim3 grid = config.grid(); + + auto stream = at::cuda::getCurrentCUDAStream(); + int shared_memory = config.shared_memory_size(); + + switch(config.output_vec_size) { + case 4: + reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: + reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +inline void launch_jitted_reduce_kernel( + std::mutex &jiterator_mutex, + std::array &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, + int vt0, const ReduceConfig& config, const void *reduction) { + dim3 block = config.block(); + dim3 grid = config.grid(); + + int shared_memory = config.shared_memory_size(); + at::cuda::jit::NvrtcFunction* fn_ptr; + switch(config.output_vec_size) { + case 4: + fn_ptr = &fn_cache[0]; + break; + case 2: + fn_ptr = &fn_cache[1]; + break; + default: + fn_ptr = &fn_cache[2]; + } + if (!fn_ptr->function) { + int max_threads_codegen = + max_reduce_threads(desc.f_inputs_type) / config.output_vec_size; + auto code = at::cuda::jit::generate_reduction_code( + desc, vt0, true, false, config.output_vec_size, max_threads_codegen); + + *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_" + desc.name); + } + constexpr int kernel_args = 1; + const void* args[kernel_args]; + args[0] = reduction; + at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory); +} + + +class AccumulationBuffer { + public: + AccumulationBuffer() {} + + AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) { + out_ptr_ = (char*)out_ptr; + if (out_t_size >= acc_t_size) { + // reusing output buffer for accumulation. 
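+      // (Illustrative note, not part of the upstream header.) A float -> float reduction
+      // (acc_t_size == out_t_size == 4) takes this branch and accumulates directly in the
+      // output tensor. A half-precision output with a float accumulator
+      // (out_t_size == 2 < acc_t_size == 4) falls through to the else branch below: a
+      // separate buffer is allocated and numerator_/denominator_ reduce to 2/1, so
+      // get_acc_slice() maps an output byte offset to the matching accumulator byte offset.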
+ acc_ptr_ = (char*)out_ptr; + numerator_ = 1; + denominator_ = 1; + } else { + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + buffer_ = allocator.allocate(size); + acc_ptr_ = (char*)buffer_.get(); + numerator_ = acc_t_size; + denominator_ = out_t_size; + reduce_fraction(numerator_, denominator_); + } + } + + char* get_acc_slice(char* out_ptr) { + if (acc_ptr_ == nullptr) { + return nullptr; + } + return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_); + } + + private: + char* acc_ptr_ = nullptr; + char* out_ptr_ = nullptr; + size_t numerator_; + size_t denominator_; + at::DataPtr buffer_; +}; + +template +int get_output_vec_size(const TensorIterator &iter) { + int vec_size = 4; + auto update_vec_size = [&vec_size](uint64_t n) { + while(n % vec_size != 0) { + vec_size /= 2; + } + }; + + uint64_t base_address = reinterpret_cast(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t); + update_vec_size(base_address); + + const int output_index = iter.num_reduce_dims(); + update_vec_size(iter.shape()[output_index]); + + int j = 0; + for(auto i : iter.strides(iter.noutputs())) { + if (j != output_index) { + update_vec_size(i / sizeof(scalar_t)); + } + j++; + } + return vec_size; +} + +template +ReduceConfig setReduceConfig(const TensorIterator& iter){ + // Start by assuming that each thread handles a single output and all + // the inputs for that output. + int64_t num_outputs = iter.num_output_elements(); + int64_t inputs_per_output = iter.numel() / num_outputs; + int input_index = iter.ntensors() - 1; + + auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output); + + int64_t dim0; + int64_t dim1; + int64_t fastest_moving_stride; + bool reduction_on_fastest_striding_dimension; + + if (iter.ndim() > 0) { + // Adjust block size to map block width to fastest changing dimension of input + // tensor. This grants the best possible memory accessing pattern, given that + // for non-contiguous tensor with space in between, we cannot have perfect + // memory coalescing. + reduction_on_fastest_striding_dimension = + (iter.num_reduce_dims() == iter.ndim()) || + (iter.strides(/*arg=*/input_index)[0] < + iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]); + // Notice that dim0 & dim1 does NOT guarantee any launch configuration here! + // dim0 & dim1 are more like the upper bound of the block dimension. The + // actual launch config and reduction scheme is determined by setting values + // to `config.input_mult` and `config.output_mult`. + // We try to max out dim1 so that we have enough threads per CTA to deliver + // performance for larger problem size. + if (reduction_on_fastest_striding_dimension) { + // Map block.x to the fastest reducing dimension. It implies: + // 1. block_x_reduce is required. + // 2. block.y now max out to num_outputs. + dim0 = inputs_per_output; + dim1 = num_outputs; + fastest_moving_stride = iter.strides(/*arg=*/input_index)[0]; + } else { + // Map block.x to the fastest non reducing dimension. It implies: + // 1. block_x_reduce is turned off. + // 2. block.y now max out to inputs_per_output. + dim0 = num_outputs; + dim1 = inputs_per_output; + fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]; + } + } else { + reduction_on_fastest_striding_dimension = true; + fastest_moving_stride = sizeof(scalar_t); + dim0 = 1; + dim1 = 1; + } + + // We do vectorization to gain better memory access, there are two cases which we call + // "vectorize along input" and "vectorize along output". 
Note that the "input/output" + // here does not mean we are vectorizing load/store instructions. We always only vectorize + // load instructions. + // + // Case 1: "vectorize along input" + // This case happens when we are reducing along fastest moving dimesion. In such case, threads + // with the same threadIdx.y works on the same reduction cooperatively and will produce results + // for the same output. In such case, values in each loaded vector always correspond to the same output. + // + // Case 2: "vectorize along output" + // This case happens when the fastest moving dimesion is not the dimension of reduction. In such case, + // threads with different threadIdx.x are independent and will produce results for different outputs. + // In such case, values in each loaded vector always correspond to different outputs. + if (fastest_moving_stride == sizeof(scalar_t)) { +#ifdef USE_ROCM + if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1) { +#else + if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) { +#endif + // Case 1: "vectorize along input" + // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case, + // we should avoid vectorization. + config.vectorize_input = true; + dim0 /= input_vec_size; + } else if (!reduction_on_fastest_striding_dimension) { + // Case 2: "vectorize along output" + config.output_vec_size = get_output_vec_size(iter); + dim0 /= config.output_vec_size; + } + } + + // Adjust block_width and block_height + config.set_block_dimension(dim0, dim1); + + int block_width = config.block_width; + int block_height = config.block_height; + + if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) { + // Split the input across lanes if the input is contiguous in the reduced + // dimension. This will require reduction between threads using warp + // shuffle instructions and shared memory (if block_width > warpSize). + config.input_mult[0] = config.split_input(block_width); + } else { + // Otherwise split the output across lanes in a warp. + config.output_mult[0] = config.split_output(block_width); + } + + constexpr int min_values_per_thread = 16; + constexpr int max_values_per_thread = 256; + + const int warp_split_threshold = + std::min(block_height * 16, max_values_per_thread); + bool split_across_warps = config.values_per_thread() >= warp_split_threshold; + const int num_mp = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount; +#ifdef USE_ROCM + bool force_splitting_output = iter.ndim() == 2 && + reduction_on_fastest_striding_dimension && + config.values_per_thread() < 1024 && num_mp < 100; + split_across_warps = !force_splitting_output && split_across_warps; +#endif + + if (split_across_warps) { + // Divide the input across warps in a thread-block, if that leaves at least + // 16 elements to be summed by each thread. This will require inter-warp + // reduction using shared memory. + config.input_mult[1] = config.split_input(block_height); + } else { + // Otherwise, each warp handles a separate output. + config.output_mult[1] = config.split_output(block_height); + } + + int max_threads_per_mp = + at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor; +#ifdef USE_ROCM + // If the grid consists of a single threadblock, do not change the max threads per + // MP value. This will increase the parallelism across the y dimension of the grid. 
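+  // (Illustrative note, not part of the upstream header.) For example, a full reduction
+  // of a small tensor launches a grid of (1, 1, 1); keeping the device default
+  // maxThreadsPerMultiProcessor then yields a larger target_grid_size below, which in
+  // turn permits a larger ctas_per_output, i.e. more thread blocks along the grid's
+  // y dimension.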
+ bool uses_a_single_block = config.grid().x == config.grid().y == config.grid().z == 1; + + if (!uses_a_single_block) { + // Control the number of threadblocks by adjusting the maximum number of + // threads per multi-processor. These numbers better reflect the maximum + // theoretical achievable threads per MP for the reduction operation. + if (iter.ndim() == 1 || iter.ndim() == 3) + max_threads_per_mp = 512; + else if (iter.ndim() == 2) + max_threads_per_mp = 256; + } +#endif + const int blocks_per_sm = max_threads_per_mp / config.num_threads; + const int target_grid_size = num_mp * blocks_per_sm; + int grid = config.grid().x; + if (config.input_mult[1] != 0 && config.values_per_thread() >= max_values_per_thread && grid <= target_grid_size) { + // Divide the input across thread-blocks if the amount of work per-thread + // is large enough and the size of the output is small enough. This will + // require a reduction using global memory. + // If we decide to split input across blocks, as long as we can get enough + // number of blocks (`target_grid_size`) to balance SM, we should still + // make the number of values per thread large for best performance. + int ctas_per_output1 = div_up(target_grid_size, grid); + int ctas_per_output2 = div_up(config.values_per_thread(), min_values_per_thread); + int ctas_per_output3 = div_up(config.values_per_thread(), max_values_per_thread); + // We want the minimum of ctas_per_output1 and ctas_per_output2, so that each thread can have + // a large number of values to deal with. But we don't want values_per_thread to be larger than + // max_values_per_thread + config.ctas_per_output = std::max(std::min(ctas_per_output1, ctas_per_output2), ctas_per_output3); +#ifdef USE_ROCM + // In cases where a number of threadblocks along the y direction of the grid + // is needed then make sure they are reduced to the number of MPs. For + // smaller sizes, use half the number of MPs. For smaller sizes than half + // the number of MPs use the original value unless the value is less than 16 + // blocks in which case it is more profitable to use just 1 block. + if (config.ctas_per_output > num_mp) + if (num_mp < 128) + config.ctas_per_output = + num_mp * (config.ctas_per_output > 512 ? 4 : 2); + else + config.ctas_per_output = num_mp; + else if (config.ctas_per_output > div_up(num_mp, 2)) + config.ctas_per_output = div_up(num_mp, 2); + else if (config.ctas_per_output < 16) + config.ctas_per_output = 1; + bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); + if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) + config.ctas_per_output = 4; +#endif + if (config.ctas_per_output > 1) { + config.input_mult[2] = config.split_input(config.ctas_per_output); + } + } + return config; +}; + +template +inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0, + AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) { + AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1); + + using traits = function_traits; + using arg_t = typename traits::template arg<0>::type; + // at::Half/at::ComplexHalf overflows easily as it's range is very small. + // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we + // set can_accumulate_in_output to False. 
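+  // (Illustrative note, not part of the upstream header.) at::Half has a largest finite
+  // value of 65504, so accumulating even a few thousand moderately sized partial sums
+  // directly in a half-precision output would overflow; accumulation therefore stays in
+  // the wider arg_t instead.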
+ static constexpr bool is_inp_out_type_half_or_chalf = + (std::is_same_v && + std::is_same_v) || + (std::is_same_v, scalar_t> && + std::is_same_v, out_scalar_t>); + // at::BFloat16 has lower precision and can lead to rounding errors. + // So when scalar_t and out_scalar_t are at::BFloat16, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_bfloat16 = + (std::is_same_v && + std::is_same_v); + static constexpr bool can_accumulate_in_output = + std::is_convertible_v && + !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16); + + bool can_use_32bit_indexing = iter.can_use_32bit_indexing(); + std::unique_ptr owned_buf_ptr; + // The acc_buf_ptr is a shared pointer. It is create at the first entrance and + // reused by all recursive function calls. + if (acc_buf_ptr == NULL) { + // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter + // when accumulation in output is not possible. + if (!can_accumulate_in_output && !can_use_32bit_indexing) { + int64_t output_memory_size = iter.element_size(0); + for (int dim = 0; dim < iter.ndim(); dim++) { + output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]); + } + output_memory_size /= iter.element_size(0); //iter.strides is in bytes + owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t), + sizeof(out_scalar_t), + (char*) iter.data_ptr(0), + output_memory_size * sizeof(arg_t))); + } else { + owned_buf_ptr.reset(new AccumulationBuffer()); + } + acc_buf_ptr = owned_buf_ptr.get(); + } + + if (!can_use_32bit_indexing) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + int64_t sub_iter_base_idx = sub_iter.view_offsets()[0]; + + gpu_reduce_kernel(sub_iter, ops, ident, + acc_buf_ptr, sub_iter_base_idx); + } + return; + } + + const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1); + char* out_data = (char*)iter.data_ptr(0); + const auto noutputs = iter.noutputs(); + std::optional out_data_extra; + if (noutputs > 1) { + out_data_extra = (char*)iter.data_ptr(1); + } else { + out_data_extra = std::nullopt; + } + char* acc_data = acc_buf_ptr->get_acc_slice(out_data); + + ReduceConfig config = setReduceConfig(iter); + at::DataPtr buffer; + at::DataPtr semaphores; + if (config.should_global_reduce()) { + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + buffer = allocator.allocate(config.global_memory_size()); + semaphores = allocator.allocate(config.semaphore_size()); + + auto stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream)); + } + + AT_ASSERT(can_use_32bit_indexing); + auto output_calc = make_output_calculator(iter); + auto input_calc = make_input_calculator(iter); + auto reduce = ReduceOp( + ops, + config, + input_calc, + output_calc, + in_data, + out_data, + out_data_extra, + acc_data, + buffer.get(), + (int*)semaphores.get(), + ident, + noutputs, + base_idx); + reduce.accumulate = iter.should_accumulate(); + reduce.final_output = iter.is_final_output(); + + launch_reduce_kernel::MAX_NUM_THREADS>(config, reduce); +} + +//TODO this is 100 lines of almost-copy-paste, because we have to have different template args for this function +//try unifying with gpu_reduce_kernel +template +inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0, + AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) { + AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1); + + //TODO - this 
will be different for more complicated reductions, but for now reductions using + //func_wrapper all have arg_t = opmath + using arg_t = at::opmath_type; + // at::Half/at::ComplexHalf overflows easily as it's range is very small. + // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_half_or_chalf = + (std::is_same_v && + std::is_same_v ) || + (std::is_same_v, scalar_t> && + std::is_same_v, out_scalar_t>); + // at::BFloat16 has lower precision and can lead to rounding errors. + // So when scalar_t and out_scalar_t are at::BFloat16, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_bfloat16 = + (std::is_same_v && + std::is_same_v); + static constexpr bool can_accumulate_in_output = + std::is_convertible_v && + !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16); + + bool can_use_32bit_indexing = iter.can_use_32bit_indexing(); + std::unique_ptr owned_buf_ptr; + + // The acc_buf_ptr is a shared pointer. It is create at the first entrance and + // reused by all recursive function calls. + if (acc_buf_ptr == NULL) { + // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter + // when accumulation in output is not possible. + if (!can_accumulate_in_output && !can_use_32bit_indexing) { + int64_t output_memory_size = iter.element_size(0); + for (int dim = 0; dim < iter.ndim(); dim++) { + output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]); + } + output_memory_size /= iter.element_size(0); //iter.strides is in bytes + owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), //TODO + sizeof(out_scalar_t), + (char*) iter.data_ptr(0), + output_memory_size * sizeof(out_scalar_t))); //TODO + } else { + owned_buf_ptr.reset(new AccumulationBuffer()); + } + acc_buf_ptr = owned_buf_ptr.get(); + } + + if (!can_use_32bit_indexing) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + int64_t sub_iter_base_idx = sub_iter.view_offsets()[0]; + + jitted_gpu_reduce_kernel(sub_iter, func, ident, + acc_buf_ptr, sub_iter_base_idx); + } + return; + } + + //TODO - for now we support a single input, we may be able to relax this constraint + const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1); + char* out_data = (char*)iter.data_ptr(0); + const auto noutputs = iter.noutputs(); + std::optional out_data_extra; + if (noutputs > 1) { + out_data_extra = (char*)iter.data_ptr(1); + } else { + out_data_extra = std::nullopt; + } + char* acc_data = acc_buf_ptr->get_acc_slice(out_data); + + ReduceConfig config = setReduceConfig(iter); + + at::DataPtr buffer; + at::DataPtr semaphores; + if (config.should_global_reduce()) { + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + buffer = allocator.allocate(config.global_memory_size()); + semaphores = allocator.allocate(config.semaphore_size()); + + auto stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream)); + } + + AT_ASSERT(can_use_32bit_indexing); + auto output_calc = make_output_calculator(iter); + auto input_calc = make_input_calculator(iter); + auto reduce = ReduceJitOp( + config, + input_calc, + output_calc, + in_data, + out_data, + out_data_extra, + acc_data, + buffer.get(), + (int*)semaphores.get(), + ident, + noutputs, + base_idx); + reduce.accumulate = iter.should_accumulate(); + reduce.final_output = iter.is_final_output(); + + constexpr int nInputs = 1; + constexpr 
int nOutputs = 1; + static auto desc = at::cuda::jit::make_kernel_descriptor< + out_scalar_t, scalar_t>(name, func, nInputs, nOutputs); + + static std::mutex jiterator_mutex; + static std::vector> fn_cache(c10::cuda::device_count()); + auto &cache = fn_cache[iter.device().index()]; + + launch_jitted_reduce_kernel( + jiterator_mutex, cache, desc, vt0, config, &reduce); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ReduceOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ReduceOps.h new file mode 100644 index 0000000000000000000000000000000000000000..7deae24af5576040634c2ae29ccba239f5c992d4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ReduceOps.h @@ -0,0 +1,20 @@ + +namespace at { +struct TensorIterator; +} + +namespace c10 { +class Scalar; +} + +namespace at::native { + +void norm_launch_kernel(TensorIterator &iter, double val); +void min_launch_kernel(TensorIterator &iter); +void max_launch_kernel(TensorIterator &iter); +void aminmax_launch_kernel(TensorIterator &iter); +void min_all_launch_kernel(TensorIterator &iter); +void max_all_launch_kernel(TensorIterator &iter); +void aminmax_allreduce_launch_kernel(TensorIterator &iter); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Resize.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Resize.h new file mode 100644 index 0000000000000000000000000000000000000000..95b3014d4ed82eab44a2c955a9a0e7af0e579ba9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Resize.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include + +namespace at::native { + +TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes); + +static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) { + // It does not make sense to try to resize a storage + // to hold 0 elements, and this can break + // if storage_offset is positive but + // new_size is 0, so just bail in that case + // (same comment is in Resize.h) + if (self->numel() == 0) { + return; + } + + const Storage &storage = self->unsafe_storage(); + TORCH_CHECK(storage, "Tensor: invalid null storage"); + if (new_size_bytes > storage.nbytes()) { + resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes); + } +} + +inline TensorImpl* resize_impl_cuda_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride) { + if (self->sizes() == size && (!stride || self->strides() == stride)) { + return self; + } + const auto itemsize = self->dtype().itemsize(); + const auto storage_offset = self->storage_offset(); + size_t storage_size = 1; + if (stride) { + self->set_sizes_and_strides(size, *stride); + storage_size = at::detail::computeStorageNbytes( + size, *stride, itemsize, storage_offset); + } else { + self->set_sizes_contiguous(size); + storage_size = at::detail::computeStorageNbytesContiguous( + size, itemsize, storage_offset); + } + maybe_resize_storage_cuda(self, storage_size); + + return self; +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h new file mode 100644 index 0000000000000000000000000000000000000000..3c9836bc8aa6d1240c27afbf811ad9a9a34bcf70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include + +namespace at::cuda::detail { 
+TORCH_API void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + std::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScaledGroupMM.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScaledGroupMM.h new file mode 100644 index 0000000000000000000000000000000000000000..a2dafda0aaeccbaa6b8391d801bcd2853e89a516 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScaledGroupMM.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + +namespace at::cuda::detail { +TORCH_API void f8f8bf16_grouped_mm( + at::Tensor mat_a, // FP8 + at::Tensor mat_b, // FP8 + at::Tensor scale_a, // FP32 + at::Tensor scale_b, // FP32 + std::optional offs, + std::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScanKernels.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScanKernels.h new file mode 100644 index 0000000000000000000000000000000000000000..fbc3d974cf9684205036fa4d71f4fbebbfd1ed52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScanKernels.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; + +namespace native { + +// NOTE: these functions require output tensors to be contiguous +void launch_cummax_cuda_kernel(const TensorBase& self, const TensorBase& values, + const TensorBase& indices, int64_t dim); +void launch_cummin_cuda_kernel(const TensorBase& self, const TensorBase& values, + const TensorBase& indices, int64_t dim); +void launch_logcumsumexp_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); +void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); +void launch_cumprod_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); + +}} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..210b0ceaceb900f67f11a052bbed15cfca675309 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh @@ -0,0 +1,475 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native { + +template +constexpr inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) { + integer log_num_threads_x = 0; + integer log_num_threads_y = 0; + while (((integer)1 << log_num_threads_x) < row_size) { + ++log_num_threads_x; + } + while (((integer)1 << log_num_threads_y) < num_rows) { + ++log_num_threads_y; + } + // we want to keep the ratio between the x-threads and y-threads about the same as + // the ratio between the row_size and num_rows, but the total number of threads in + // a block should be about 512 + integer diff = log_num_threads_x - log_num_threads_y; + // 9 is from log2(512) + log_num_threads_x = ((integer)9 + diff) / (integer)2; + // I found that in having larger log_num_threads_x can give significant speed up in some cases, + // but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it + // similar 
to the previous implementation + // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block. + log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9); + return log_num_threads_x; +} + +template +__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) { + if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) { + rhs = lhs; + rhs_idx = lhs_idx; + } +} +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). + */ +template +__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + int num_rows, int row_size, + const uint32_t num_threads, const uint32_t log_num_threads_x, + scalar_t init, BinaryFunction binary_op) { + // dynamic memory allocation for vbuf and ibuf + alignas(sizeof(double)) extern __shared__ char buf[]; + scalar_t* vbuf = reinterpret_cast(buf); // the size is num_threads * 2 + int64_t* ibuf = reinterpret_cast(vbuf + num_threads * 2); + const uint32_t num_threads_x = 1 << log_num_threads_x; + scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y; + int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y; + + for (int block_row = blockIdx.x * blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + int row = block_row + threadIdx.y; + const scalar_t *row_self = self_ + row * row_size; + scalar_t *row_values = values_ + row * row_size; + int64_t *row_indices = indices_ + row * row_size; + scalar_t block_total = init; + int64_t block_idx_final = 0; + const bool row_exists = row < num_rows; + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + int col1 = block_col + threadIdx.x; + int col2 = block_col + num_threads_x + threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = c10::load(&row_self[col1]); + row_idx_buf[threadIdx.x] = col1; + } else { + row_buf[threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]); + row_idx_buf[num_threads_x + threadIdx.x] = col2; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. 
The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (uint32_t s = 1; s <= num_threads_x; s <<= 1) { + if (row_exists) { + uint32_t a = (threadIdx.x / s) * (2 * s) + s; + uint32_t ti = a + (threadIdx.x % s); + uint32_t si = a - 1; + binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op); + } + __syncthreads(); + } + + // Write back to output. + if (row_exists) { + if (col1 < row_size){ + row_values[col1] = row_buf[threadIdx.x]; + row_indices[col1] = row_idx_buf[threadIdx.x]; + } + if (col2 < row_size) { + row_values[col2] = row_buf[num_threads_x + threadIdx.x]; + row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x]; + } + } + block_total = row_buf[2 * num_threads_x - 1]; + block_idx_final = row_idx_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +/* Perform an inclusive scan along an outer dimension of a tensor. + * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to compute the variance; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. + */ +template +__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *self = self_ + orow * row_size * num_irows + irow; + scalar_t *values = values_ + orow * row_size * num_irows + irow; + int64_t *indices = indices_ + orow * row_size * num_irows + irow; + scalar_t out = init; + int64_t out_idx = 0; + + for (auto col = decltype(row_size){0}; col < row_size; ++col) { + const auto val = c10::load(self); + if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) { + out = val; + out_idx = col; + } + *values = out; + *indices = out_idx; + self += num_irows; + values += num_irows; + indices += num_irows; + } + } + } +} + +inline void check_fits_in_unsigned(int64_t val, const char* name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK( + val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value"); +} + + +template +__host__ void scan_outer_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + int dim, scalar_t init, BinaryFunction binary_op) { + int64_t row_size = self.size(dim); + auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. 
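+  // (Illustrative note, not part of the upstream header.) For a contiguous tensor of
+  // shape [4, 5, 6, 7] scanned along dim == 1: num_orows == 4, row_size == 5 and
+  // num_irows == 6 * 7 == 42; element (orow, col, irow) of the flattened view lives at
+  // offset orow * row_size * num_irows + col * num_irows + irow, which is exactly how
+  // the kernel above advances its pointers.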
+ const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + tensor_kernel_scan_outer_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__host__ void scan_innermost_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + // Treat all outer dimensions as a single dimension. + int row_size = self.size(ndim - 1); + int num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y)))); + + const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t)); + tensor_kernel_scan_innermost_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_rows, row_size, num_threads, log_num_threads_x, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices, //int64_t dim) { + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous()); + if (dim == ndim - 1) { + scan_innermost_dim_with_indices(*self_, values, indices, init, binary_op); + } else { + scan_outer_dim_with_indices(*self_, values, indices, dim, init, binary_op); + } +} + +// TODO: The implementation of `tensor_kernel_scan_outer_dim` and +// `tensor_kernel_scan_innermost_dim` is similar to +// `tensor_kernel_scan_outer_dim_with_indices` +// `tensor_kernel_scan_outer_dim_with_indices` and should be refactored to +// remove the duplication. + +/* Perform an inclusive scan along an outer dimension of a tensor. + * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to scan; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. 
+ */ +template +__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) +{ + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *src = src_ + orow * row_size * num_irows + irow; + scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; + scalar_t acc = init; + + for (uint32_t col = 0; col < row_size; ++col) { + acc = binary_op(acc, c10::load(src)); + *tgt = acc; + + src += num_irows; + tgt += num_irows; + } + } + } +} + +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). + */ +template +__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_, + const uint32_t num_rows, const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, BinaryFunction binary_op){ + const index_t num_threads_x = 1 << log_num_threads_x; + for (index_t block_row = blockIdx.x * (index_t) blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + index_t row = block_row + (index_t) threadIdx.y; + T block_total = init; + + const T *row_src = src_ + row * row_size; + T *row_tgt = tgt_ + row * row_size; + const bool row_exists = row < num_rows; + + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (index_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + index_t col1 = block_col + (index_t) threadIdx.x; + index_t col2 = block_col + num_threads_x + (index_t) threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = row_src[col1]; + } else { + row_buf[threadIdx.x] = init; + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = row_src[col2]; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + row_buf[0] = binary_op(row_buf[0], block_total); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (int m = 0; m <= log_num_threads_x; ++m) { + if (row_exists) { + index_t s = 1 << m; // s = 2 ^ m + auto a = static_cast((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s + index_t ti = a + (threadIdx.x % s); + index_t si = a - 1; + row_buf[ti] = binary_op(row_buf[ti], row_buf[si]); + } + __syncthreads(); + } + + // Write back to output. 
+ if (row_exists) { + if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x]; + if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x]; + } + block_total = row_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +template < + typename T, + class BinaryFunction> +__global__ void tensor_kernel_scan_innermost_dim( + T* tgt_, + const T* src_, + const uint32_t num_rows, + const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, + BinaryFunction binary_op) { + alignas(sizeof(double)) extern __shared__ char sbuf[]; + T* sbuf2 = reinterpret_cast(sbuf); + const uint32_t num_threads_x = 1 << log_num_threads_x; + T* row_buf = reinterpret_cast(sbuf2 + num_threads_x * 2 * threadIdx.y); + if (num_rows * (size_t) row_size <= UINT_MAX) { + tensor_kernel_scan_innermost_dim_impl( + row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op); + } else { + tensor_kernel_scan_innermost_dim_impl( + row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op); + } +} + + +template +__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result, + int dim, scalar_t init, BinaryFunction binary_op) { + const int64_t row_size = self.size(dim); + auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. + const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_outer_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_innermost_dim(const TensorBase& self, const TensorBase& result, + scalar_t init, BinaryFunction binary_op) { + int64_t ndim = self.dim(); + // Treat all outer dimensions as a single dimension. 
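+  // (Illustrative note, not part of the upstream header.) For num_rows == 1024 and
+  // row_size == 4096, get_log_num_threads_x_inner_scan sees log2 bounds of 12 and 10,
+  // so log_num_threads_x == (9 + 2) / 2 == 5 and the launch below uses a (32, 16)
+  // block: 32 threads cooperate on each row while 16 rows share a thread block.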
+ int64_t row_size = self.size(ndim - 1); + int64_t num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; + dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y}))); + + check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_innermost_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_rows, row_size, log_num_threads_x, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim(const TensorBase& self, const TensorBase& result, + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + TORCH_INTERNAL_ASSERT(result.is_contiguous()); + + if (self.numel() == self.size(dim)) { + if constexpr (std::is_same_v>) { + if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms()) && (self.is_floating_point() || self.is_complex())) { +#if defined(CUDA_VERSION) || defined(USE_ROCM) + cuda::cub::inclusive_deterministic_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); +#else + globalContext().alertNotDeterministic("cumsum_cuda_kernel"); + cuda::cub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); +#endif + } else { + cuda::cub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); + } + } else { + cuda::cub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); + } + } else if (dim == ndim - 1) { + scan_innermost_dim(*self_, result, init, binary_op); + } else { + scan_outer_dim(*self_, result, dim, init, binary_op); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Sort.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Sort.h new file mode 100644 index 0000000000000000000000000000000000000000..7b6927f25d012d0ab4d34e49f6074ff9863a1aae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Sort.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include +#include + + +namespace at::native { + +inline bool should_use_small_sort(const TensorBase &self, int64_t dim) { + return self.size(dim) <= 4096; +} + +void sortKeyValueInplace( + const TensorBase &key, const TensorBase &value, int64_t dim, + bool descending, bool stable=false); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortStable.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortStable.h new file mode 100644 index 0000000000000000000000000000000000000000..b58c4f2463794ee441968abfbb763bc870aba788 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortStable.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include + +namespace at::native { + +// Stable-sort self into values, and set indices to the +// inverse-permutation from values back to self. +// Output tensors must be pre-allocated and contiguous. 
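[Editor's note] As I read the comment above, the contract matches torch.sort: values is the stably sorted input and indices[i] records which position in self produced values[i], with equal keys keeping their original order. A minimal CPU sketch of that contract (ascending only; file name hypothetical):

// stable_sort_indices_demo.cpp -- host-side sketch, not part of the header, of
// the (values, indices) contract: values[i] == self[indices[i]], and stability
// keeps the original relative order of equal keys.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  const std::vector<float> self = {2.0f, 1.0f, 2.0f, 0.5f};

  std::vector<int64_t> indices(self.size());
  std::iota(indices.begin(), indices.end(), 0);
  std::stable_sort(indices.begin(), indices.end(),
                   [&](int64_t a, int64_t b) { return self[a] < self[b]; });

  std::vector<float> values(self.size());
  for (size_t i = 0; i < self.size(); ++i) {
    values[i] = self[indices[i]];
  }

  // Equal keys (the two 2.0f entries at positions 0 and 2) keep their order.
  assert((indices == std::vector<int64_t>{3, 1, 0, 2}));
  assert((values == std::vector<float>{0.5f, 1.0f, 2.0f, 2.0f}));
}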
+void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..02973ad7e52ad1ca61505cebcf6d5eeff37e4a9e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh @@ -0,0 +1,343 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include + +#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600) + + +namespace at::native { + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA, + K& kB, V& vB, bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSort(K *keys, + V *values, + bool *valid, + const Comparator& comp) { +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + flag, comp); + } + } + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + false, comp); + } + + __syncthreads(); + +} + +// at::cuda::detail::TensorInfo version +// Sorts (key, value) pairs (in different tensors) in-place; i.e., +// modifies the input `keys` and `values` +template +C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y) +__global__ void +bitonicSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If the entire block is out of bounds exit early + if (blockIndex * blockDim.y >= keySlices) { + return; + } + // It's also possible for some rows of a block to be out of bounds + // but all thread need to run for __syncthreads to work. 
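[Editor's note] The bitonicSort network above is easier to follow when replayed serially. The sketch below, not part of the header, simulates one "thread" per compare/exchange with the same size/stride/pos arithmetic, using a less-than comparator and the swap == dir convention; validity flags and the value payload are omitted, so this is only an illustration of the network shape.

// bitonic_sort_demo.cpp -- serial re-enactment of the network; N must be a
// power of two, and each simulated "thread" t in [0, N/2) does one
// compare/exchange per stage.
#include <algorithm>
#include <cassert>
#include <vector>

void bitonic_sort(std::vector<int>& keys) {
  const unsigned N = static_cast<unsigned>(keys.size());
  auto cmp_ex = [&](unsigned i, unsigned j, bool dir) {
    // dir == false puts the smaller key first (ascending), mirroring the
    // 'swap == dir' convention of bitonicSwap with a less-than comparator.
    bool swap = keys[i] < keys[j];
    if (swap == dir) std::swap(keys[i], keys[j]);
  };
  for (unsigned size = 2; size < N; size *= 2) {          // build bitonic runs
    for (unsigned stride = size / 2; stride > 0; stride /= 2) {
      for (unsigned t = 0; t < N / 2; ++t) {
        bool flag = (t & (size / 2)) != 0;
        unsigned pos = 2 * t - (t & (stride - 1));
        cmp_ex(pos, pos + stride, flag);
      }
    }
  }
  for (unsigned stride = N / 2; stride > 0; stride /= 2) { // final ascending merge
    for (unsigned t = 0; t < N / 2; ++t) {
      unsigned pos = 2 * t - (t & (stride - 1));
      cmp_ex(pos, pos + stride, /*dir=*/false);
    }
  }
}

int main() {
  std::vector<int> v = {7, 3, 6, 0, 5, 2, 4, 1};
  std::vector<int> ref = v;
  std::sort(ref.begin(), ref.end());
  bitonic_sort(v);
  assert(v == ref);
}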
+ const bool row_valid = linearIndex < keySlices; + + constexpr int items_per_thread = 2; + constexpr int Power2SortSize = block_dim_x * items_per_thread; + + // Storage for max_block_dim_y sorts performed in parallel + __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize]; + __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize]; + __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize]; + + auto sharedKeys = blockSharedKeys[threadIdx.y]; + auto sharedValues = blockSharedValues[threadIdx.y]; + auto sharedValid = blockSharedValid[threadIdx.y]; + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + // Load 2 values per thread into the shared workspace + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + bool valid = row_valid && idx < keySliceSize; + + sharedKeys[idx] = valid ? + keys.data[idx * keySliceStride + keyStartOffset] : K{}; + sharedValues[idx] = valid ? + values.data[idx * valueSliceStride + valueStartOffset] : V{}; + sharedValid[idx] = valid; + } + + // Sort! + bitonicSort( + sharedKeys, sharedValues, sharedValid, comp); + + if (!row_valid) { + return; + } + + // Store outputs + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + if (idx < keySliceSize) { + keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx]; + values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx]; + } + } +} + +#if HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y) +__global__ void +warpMergeSortKVInPlace( + at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp, + K invalid_key) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If this row is out of bounds exit early + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE); + CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y); + constexpr int items_per_thread = sort_size / C10_WARP_SIZE; + static_assert( + items_per_thread * C10_WARP_SIZE == sort_size, + "sort_size must be a multiple of C10_WARP_SIZE"); + + + using LoadKeys = cub::WarpLoad; + using LoadValues = cub::WarpLoad; + using Sort = cub::WarpMergeSort; + using StoreKeys = cub::WarpStore; + using StoreValues = cub::WarpStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage[max_block_dim_y]; + + auto& 
warp_storage = tmp_storage[threadIdx.y]; + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + const auto invalid_value = V{}; + LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + WARP_SYNC(); + LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + WARP_SYNC(); + + // Sort! We use stable sort to ensure that invalid values are never + // sorted before valid values. In testing it performed the same as + // .Sort, so there is no down-side. + Sort(warp_storage.sort).StableSort( + local_keys, local_values, comp, keySliceSize, invalid_key); + WARP_SYNC(); + + // Store outputs + StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + WARP_SYNC(); + StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +#endif // HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(block_size) +__global__ void +radixSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + bool descending) { + static_assert(block_size > 0, ""); + + // Find the slice of the tensor that we are sorting + const IndexType linearIndex = getLinearBlockId(); + // Tiling the slices could have us be out of bounds, if there are a + // lot of slices to sort + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + using key_t = typename at::cuda::cub::detail::cuda_type::type; + using LoadKeys = cub::BlockLoad; + using LoadValues = cub::BlockLoad; + using Sort = cub::BlockRadixSort; + using StoreKeys = cub::BlockStore; + using StoreValues = cub::BlockStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage; + + // cub's Block operations operate on a fixed number of items, but the + // actual slice we are sorting might be smaller. So, we need to make + // up the difference with keys that will always sort higher. + const K invalid_key = [descending] { + using radix_t = typename cub::Traits::UnsignedBits; + union { + K key; + radix_t radix; + } tmp; + tmp.radix = descending ? + cub::Traits::LOWEST_KEY : + cub::Traits::MAX_KEY; + return tmp.key; + }(); + const V invalid_value = static_cast(0); + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + __syncthreads(); + LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + __syncthreads(); + + // Sort! 
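[Editor's note] Before the sort call below, note why the out-of-range slots were filled with invalid_key: the tile handed to the block radix sort always holds block_size * items_per_thread items, and the padding key is chosen so it sorts past every real key and is simply never written back. A rough CPU analogy, using plain std::sort and +inf in place of cub's radix-domain MAX_KEY (file name hypothetical):

// pad_to_block_demo.cpp -- CPU analogy only, not the kernel's actual radix trick.
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

int main() {
  const int tile_size = 8;                                // fixed tile, like block_size * items_per_thread
  std::vector<float> slice = {3.f, -1.f, 2.f, 0.f, 7.f};  // slice_size = 5

  std::vector<float> tile(tile_size, std::numeric_limits<float>::infinity());
  std::copy(slice.begin(), slice.end(), tile.begin());

  std::sort(tile.begin(), tile.end());  // ascending; the +inf padding sinks to the back

  std::vector<float> sorted(tile.begin(), tile.begin() + slice.size());
  assert((sorted == std::vector<float>{-1.f, 0.f, 2.f, 3.f, 7.f}));
  // For a descending sort the padding would instead be the lowest possible key,
  // which is what the LOWEST_KEY/MAX_KEY selection above encodes in radix space.
}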
+ if (descending) { + Sort(tmp_storage.sort).SortDescending( + reinterpret_cast(local_keys), + local_values); + } else { + Sort(tmp_storage.sort).Sort( + reinterpret_cast(local_keys), + local_values); + } + __syncthreads(); + + // Store outputs + StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + __syncthreads(); + StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Sorting.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..1bb95a87b5a6279b3041ee47d578a25eb7df3d86 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/Sorting.h @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_kthvalue_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t k); +void launch_median_kernel( + const TensorBase &vals, const TensorBase &inds, + const TensorBase &in, int64_t dim, bool ignore_nan); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d053d0eb80ead941175a95a9a5835b0cc204745a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh @@ -0,0 +1,193 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// Is this questionable namespace pollution? +#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; + +#else +constexpr int MAX_BLOCK_SIZE = 1024; +#endif + +// Maximum size per grid dimension that we assume (compute capability >= 2.0) +constexpr int64_t MAX_GRID_SIZE = 65535LL; + +inline bool getGridFromTiles(int64_t gridTiles, dim3& grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + + int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? 
MAX_GRID_SIZE : gridTiles; + } + } + + grid = dim3(gridX, gridY, gridZ); + return true; +} + +template +struct GTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || + (static_cast(lhs) > static_cast(rhs)); + } +}; + +template +struct LTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || + (static_cast(lhs) < static_cast(rhs)); + } +}; + +template +__device__ __forceinline__ index_t getLinearBlockId() { + return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + + blockIdx.x; +} + +// For slice sorting in Thrust; extracts a slice index from a linear +// index and uses that for comparison +struct SliceComp { + SliceComp(int64_t size) : sliceSize(size) {} + + __device__ bool operator()(const int64_t& a, const int64_t& b) const { + // Since the slices are guaranteed to be innermost, + // the segment is just via int64_t division + int64_t segA = a / sliceSize; + int64_t segB = b / sliceSize; + return segA < segB; + } + + const int64_t sliceSize; +}; + +// For sorting in Thurst; extracts a within-slice index from a linear index +struct GlobalIndexToPerSliceIndex { + GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {} + + __device__ inline void operator()(int64_t& v) const { + v = v % sliceSize; + } + + const int64_t sliceSize; +}; + +// Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks +inline uint64_t nextHighestPowerOf2(uint64_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; +#ifndef _MSC_VER + n |= n >> 32; +#endif + n++; + + return n; +} + + +// WARNING: This function assumes input tensors are contiguous +template +void run_launcher( + const TensorBase &values, + const TensorBase &indices, + const TensorBase &self, + int64_t dim, + Launcher l) { + auto self_info = cuda::detail::getTensorInfo(self); + auto values_info = cuda::detail::getTensorInfo(values); + auto indices_info = cuda::detail::getTensorInfo(indices); + + int64_t slice_size = self.size(dim); + /* We use these structures solely to find the offset to */ + /* each slice we are operating on */ + self_info.reduceDim(dim); + values_info.reduceDim(dim); + indices_info.reduceDim(dim); + + /* Collapse all other dims */ + int collapse_self_dim = self_info.collapseDims(dim); + int collapse_values_dim = values_info.collapseDims(dim); + int collapse_indices_dim = indices_info.collapseDims(dim); + + int64_t num_slices = 1; + for (int i = 0; i < self_info.dims; ++i) { + num_slices *= self_info.sizes[i]; + } + + /* This is used as a template parameter to calculate indices. 
*/ + /* We only specialize it if all collapsed dim sizes are the */ + /* same; otherwise, we use -1 which is the specialization */ + /* parameter for arbitrary dimensions */ + int all_dims = self_info.dims; + if (values_info.dims != all_dims || indices_info.dims != all_dims) { + all_dims = -1; + } + + if (all_dims == 1) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 2) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 3) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh new file mode 100644 index 0000000000000000000000000000000000000000..503461c8b9fe6d14baddd73b381e7a015fda2304 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh @@ -0,0 +1,427 @@ +#include +#include +#include +#include +#include + +namespace at::native { + +template +struct TopKTypeConfig {}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + // Converts a float to an integer representation with the same + // sorting; i.e., for floats f1, f2: + // if f1 < f2 then convert(f1) < convert(f2) + // We use this to enable radix selection of floating-point values. + // This also gives a relative order for NaNs, but that's ok, as they + // will all be adjacent + // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff.. + // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00.. 
+ // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0 + // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(uint8_t v) { + return v; + } + + static inline __device__ uint8_t deconvert(RadixType v) { + return v; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int8_t v) { + return 128u + v; + } + + static inline __device__ int8_t deconvert(RadixType v) { + return v - 128; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int16_t v) { + static_assert(sizeof(short) == 2, ""); + return 32768u + v; + } + + static inline __device__ int16_t deconvert(RadixType v) { + return v - 32768; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int32_t v) { + static_assert(sizeof(int) == 4, ""); + return 2147483648u + v; + } + + static inline __device__ int32_t deconvert(RadixType v) { + return v - 2147483648u; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(int64_t v) { + static_assert(sizeof(int64_t) == 8, ""); + return 9223372036854775808ull + v; + } + + static inline __device__ int64_t deconvert(RadixType v) { + return v - 9223372036854775808ull; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(double v) { + RadixType x = __double_as_longlong(v); + RadixType mask = -((x >> 63)) | 0x8000000000000000; + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; + } + + static inline __device__ double deconvert(RadixType v) { + RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; + return __longlong_as_double(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::Half v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType x = __half_as_ushort(v); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; +#else + CUDA_KERNEL_ASSERT(false); + return 0u; +#endif + } + + static inline __device__ at::Half deconvert(RadixType v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + return __ushort_as_half(v ^ mask); +#else + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +#endif + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::BFloat16 v) { + RadixType x = v.x; + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; + } + + static inline __device__ at::BFloat16 deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + at::BFloat16 r; + r.x = (v ^ mask); + return r; + } +}; + +// This function counts the distribution of all input values in a +// slice we are selecting by radix digit at `radixDigitPos`, but only +// those that pass the filter `((v & desiredMask) == desired)`. +// This produces and broadcasts the seen counts for a single block only. +// `smem` must have at least `RadixSize` elements. 
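[Editor's note] The TopKTypeConfig conversions above all implement the same idea: an order-preserving map from the scalar type into an unsigned integer, so that comparing radix digits from most to least significant compares the original values. A small host-side check of two of those mappings (sketch only; the free functions below are my own stand-ins, not the header's API):

// radix_order_demo.cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// int32 version: bias the range so INT32_MIN maps to 0 (mirrors convert/deconvert).
uint32_t convert_i32(int32_t v) { return 2147483648u + static_cast<uint32_t>(v); }

// double version: flip all bits of negatives, flip only the sign bit of positives.
uint64_t convert_f64(double v) {
  uint64_t x;
  std::memcpy(&x, &v, sizeof(x));
  const uint64_t mask = (x >> 63) ? ~uint64_t{0} : (uint64_t{1} << 63);
  return x ^ mask;
}

int main() {
  std::vector<int32_t> ints = {-5, -1, 0, 1, 7};
  for (size_t i = 1; i < ints.size(); ++i)
    assert(convert_i32(ints[i - 1]) < convert_i32(ints[i]));

  std::vector<double> doubles = {-100.0, -0.5, 0.0, 0.25, 3.0};
  for (size_t i = 1; i < doubles.size(); ++i)
    assert(convert_f64(doubles[i - 1]) < convert_f64(doubles[i]));
  return 0;
}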
+template < + typename scalar_t, + typename bitwise_t, + typename index_t, + typename CountType, + int RadixSize, + int RadixBits> +__device__ void countRadixUsingMask( + CountType counts[RadixSize], + CountType* smem, + bitwise_t desired, + bitwise_t desiredMask, + int radixDigitPos, + index_t sliceSize, + index_t withinSliceStride, + const scalar_t* data) { + // Clear out per-thread counts from a previous round +#pragma unroll + for (int i = 0; i < RadixSize; ++i) { + counts[i] = 0; + } + + if (threadIdx.x < RadixSize) { + smem[threadIdx.x] = 0; + } + __syncthreads(); + + // Scan over all the data. Upon a read, the warp will accumulate + // counts per each digit in the radix using warp voting. +#if !defined(USE_ROCM) + // Must be called outside of loop to ensure all threads participate + unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize); +#endif + for (index_t i = threadIdx.x; i < sliceSize;) { + bitwise_t val = + TopKTypeConfig::convert(doLdg(&data[i * withinSliceStride])); + + bool hasVal = ((val & desiredMask) == desired); + bitwise_t digitInRadix = at::cuda::Bitfield::getBitfield( + val, radixDigitPos, RadixBits); + +#pragma unroll + for (uint32_t j = 0; j < RadixSize; ++j) { + bool vote = hasVal && (digitInRadix == j); +#if defined(USE_ROCM) + counts[j] += __popcll(WARP_BALLOT(vote)); +#else + counts[j] += __popc(WARP_BALLOT(vote, mask)); +#endif + } + i += blockDim.x; +#if !defined(USE_ROCM) + mask = WARP_BALLOT(i < sliceSize, mask); +#endif + } + + // Now, for each warp, sum values + if (at::cuda::getLaneId() == 0) { +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + gpuAtomicAddNoReturn(&smem[i], counts[i]); + } + } + + __syncthreads(); + + // For each thread, read in the total counts +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + counts[i] = smem[i]; + } + + __syncthreads(); +} + +// Over what radix we are selecting values +constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) +constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_SIZE - 1); + +// This finds the unique value `v` that matches the pattern +// ((v & desired) == desiredMask) in our sorted int format +template +__device__ scalar_t findPattern( + scalar_t* smem, + const scalar_t* data, + index_t sliceSize, + index_t withinSliceStride, + bitwise_t desired, + bitwise_t desiredMask) { + if (threadIdx.x < 2) { + smem[threadIdx.x] = static_cast(0); + } + __syncthreads(); + + // All threads participate in the loop, in order to sync on the flag + index_t numIterations = + round_up(sliceSize, static_cast(blockDim.x)); + for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? 
doLdg(&data[i * withinSliceStride]) + : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + // There should not be conflicts if we are using findPattern, + // since the result is unique + smem[0] = static_cast(1); + smem[1] = v; // can't use val as the flag, since it could be 0 + } + + __syncthreads(); + + scalar_t found = smem[0]; + scalar_t val = smem[1]; + + __syncthreads(); + + // Check to see if a thread found the value + if (found != static_cast(0)) { + // all threads return this value + return val; + } + } + + // should not get here + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +} + +// Returns the top-Kth element found in the data using radix selection +template +__device__ void radixSelect( + const scalar_t* data, + index_t k, + bool largest, + index_t sliceSize, + index_t withinSliceStride, + int* smem, + scalar_t* topK) { + // Per-thread buckets into which we accumulate digit counts in our + // radix + int counts[RADIX_SIZE]; + + // We only consider elements x such that (x & desiredMask) == desired + // Initially, we consider all elements of the array, so the above + // statement is true regardless of input. + bitwise_t desired = 0; + bitwise_t desiredMask = 0; + + // We are looking for the top kToFind-th element when iterating over + // digits; this count gets reduced by elimination when counting + // successive digits + int kToFind = k; + + // We start at the most significant digit in our radix, scanning + // through to the least significant digit + for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0; + digitPos -= RADIX_BITS) { + // Count radix distribution for the current position and reduce + // across all threads + countRadixUsingMask< + scalar_t, + bitwise_t, + index_t, + int, + RADIX_SIZE, + RADIX_BITS>( + counts, + smem, + desired, + desiredMask, + digitPos, + sliceSize, + withinSliceStride, + data); + + auto found_unique = [&](int i, int count) -> bool { + /* All threads have the same value in counts here, so all */ + /* threads will return from the function. */ + if (count == 1 && kToFind == 1) { + /* There is a unique answer. */ + desired = at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The answer is now the unique element v such that: */ + /* (v & desiredMask) == desired */ + /* However, we do not yet know what the actual element is. We */ + /* need to perform a search through the data to find the */ + /* element that matches this pattern. 
*/ + *topK = findPattern( + (scalar_t*)smem, + data, + sliceSize, + withinSliceStride, + desired, + desiredMask); + return true; + } + return false; + }; + auto found_non_unique = [&](int i, int count) -> bool { + if (count >= kToFind) { + desired = + at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The top-Kth element v must now be one such that: */ + /* (v & desiredMask == desired) */ + /* but we haven't narrowed it down; we must check the next */ + /* least-significant digit */ + return true; + } + kToFind -= count; + return false; // continue the loop + }; + + // All threads participate in the comparisons below to know the + // final result + if (largest) { + // Process in descending order +#pragma unroll + for (int i = RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } else { + // Process in ascending order +#pragma unroll + for (int i = 0; i < RADIX_SIZE; ++i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } + } // end digitPos for + + // There is no unique result, but there is a non-unique result + // matching `desired` exactly + *topK = TopKTypeConfig::deconvert(desired); +} +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ec9a87524910f0292b2f0a790f03877bc798ceed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh @@ -0,0 +1,431 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +// Used for a segmented reduction +struct ModeUnsignedBoolPair { + unsigned int val; + bool flag; +}; + +// In the kernel below, we have a common pattern of reducing (unsigned int, +// unsigned int) pairs of data +struct ModeUnsignedPair { + unsigned int val; + unsigned int index; +}; + +// Inclusive Scan via an upsweep/downsweep mechanism. Assumes: +// +// 1. Power2ScanSize is a power of 2. This code still works for collections that +// do not exactly contain a power of 2 number of elements, simply round up to +// the nearest power of 2 and then call. +// +// 2. That there are two-elements per thread, i.e. the size of the smem storage +// is 2 * blockDim.x * sizeof(T). 
+// +// Consider a (+)-Scan on the following elements: +// +// Upsweep: +// +// 0 1 2 3 4 5 6 7 +// 1 5 9 13 +// 6 22 +// 28 +// +// Downsweep: +// 15 +// 3 10 21 +template +__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) { + // Reduce step ("upsweep") +#pragma unroll + for (int stride = 1; stride < Power2ScanSize; stride <<= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if (index < Power2ScanSize) { + smem[index] = binop(smem[index], smem[index - stride]); + } + __syncthreads(); + } + + // Post-reduce step ("downsweep") +#pragma unroll + for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if ((index + stride) < Power2ScanSize) { + smem[index + stride] = binop(smem[index + stride], smem[index]); + } + __syncthreads(); + } +} + +// Block-wide reduction where each thread locally reduces N +// values before letting a single warp take over - assumes +// threadVals is in registers, not shared memory +// +// If smem is not used again, there is no need to __syncthreads before this +// call. However, if smem will be used, e.g., this function is called in a loop, +// then __syncthreads is needed either before or afterwards to prevent non-0 +// threads overriding smem in the next loop before num-0 thread reads from it. +template +__device__ T reduceBlockWithNThreadLocalReductions( + T* smem, + T threadVals[N], + const unsigned int numVals, + ReduceOp reduceOp, + T init) { + int offset = threadIdx.x * N; + T local = offset < numVals ? threadVals[0] : init; + +#pragma unroll + for (int i = 1; i < N; ++i) { + ++offset; + T next = offset < numVals ? threadVals[i] : init; + local = reduceOp.combine(local, next); + } + + return cuda_utils::BlockReduce(local, reduceOp, init, smem); +} + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap( + K& kA, + V& vA, + bool& validA, + K& kB, + V& vB, + bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSwapKeys( + K& kA, + bool& validA, + K& kB, + bool& validB, + bool dir, + const Comparator& comp) { + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(validA, validB); + } +} + +template < + typename K, + typename IndexType, + int Power2SortSize, + typename Comparator> +__device__ inline void bitonicSortKeys( + K keys[Power2SortSize], + bool valid[Power2SortSize], + const Comparator& comp) { +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + flag, + comp); + } + } + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + false, + comp); + } + + 
__syncthreads(); +} + +// The mode kernel has the following characteristics: It uses internal shared +// memory buffers of Power2Size, which must be greater than the number of +// elements. Additionally, there is one block for every slice to calculate the +// mode for, and in each block there is one thread for every two elements. +// +// Both sorted and positions are assumed to be contiguous Tensors with the mode +// dimension as the innermost dim, such that we can get the particular slice for +// a Tensor via its linear block dimension * the slice size. +template +__launch_bounds__(1024, 1) +__global__ void compute_mode( + const T* input, + at::cuda::detail::TensorInfo values, + at::cuda::detail::TensorInfo indices, + int64_t sliceSize, + int64_t slices) { + int tidx = threadIdx.x; + int stidx = blockDim.x + threadIdx.x; // Second index this thread responsible for + + // First, we need to calculate the offset into the sorted Tensor that + // represents the start of the slice for this block to calculate the mode for. + // This offset is a combination of the gridIndices, and the number of elements + // in the slice. + unsigned int blockId = getLinearBlockId(); + unsigned int linearOffset = blockId * sliceSize; + + if (blockId >= slices) { + return; + } + + // shmem is a dynamically sized buffer we will use throughout the kernel to + // handle computation efficiently. The size of this shmem must be + // sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size) + // + // Initially, the buffer will be organized as follows: + // + // [smem (slice elements) | bmem (valid indices) | ] + extern __shared__ char shmem[]; + + // smem represents a proportion of the shared memory buffer that is used to + // store the elements from the slice: + T* smem = reinterpret_cast(shmem); + + // Each thread loads up to two elements from the Tensor into shared memory + if (tidx < sliceSize) { + smem[tidx] = c10::load(&input[linearOffset + tidx]); + } + if (stidx < sliceSize) { + smem[stidx] = c10::load(&input[linearOffset + stidx]); + } + + // Next, we initialize a boolean region of the buffer, offset by the loaded + // element smem region + bool* bmem = reinterpret_cast(&smem[Power2Size]); + + // The first use of this region stores bmem[i] = i < sliceSize to mark the + // valid components in the smem buffer + bmem[tidx] = tidx < sliceSize; + bmem[stidx] = stidx < sliceSize; + __syncthreads(); // barrier for smem, bmem initialization + + // First, sort the input slice in ascending order. smem contains the input + // elements, and bmem marks the valid indices + bitonicSortKeys( + smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) { + return a < b; + }); + __syncthreads(); // make no assumptions that the sort syncs at end + + // The next step of our algorithm is performing a block-wide comparison of + // neighboring elements. In particular, given an sorted input slice A, we + // produce an output slice B, such that B[i] = 1 if A[i-i] != A[i], otherwise + // 0. + // + // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8] + // B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // + // In particular, we can think of B[i] true indicating the start of a sequence + // of equal values in the sorted list. Similarly, we will also store the + // negation of B, which we'll call C. In particular, we can think of C[i] = + // true iff A[i-1] == A[i] in our original sorted slice. 
+ // + // C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + + // We overwrite bmem, and treat the rest of shared memory as a buffer of + // (index, flag) pairs where the index represents values from C, and the flag + // represents values from B. + // + // [smem (sorted slice) | ubpmem (index, flag pairs)] + + struct ModeUnsignedBoolPair* ubpmem = + reinterpret_cast(&smem[Power2Size]); + + if (tidx == 0) { + ubpmem[0].flag = true; + ubpmem[0].val = 0; + } + + // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ... + ubpmem[tidx * 2 + 1].flag = + smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (1, 2), etc. + ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag; + + // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ... + if (((tidx + 1) * 2) < Power2Size) { + ubpmem[(tidx + 1) * 2].flag = + smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2]; + ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag; + } + __syncthreads(); // barrier for ubpmem initialization + + // Next, we perform a segmented prefix sum on the neighboring elements, where + // the presence of a one indicates the start of a segment. In this case B acts + // as the segment start flags, and C is the buffer to be summed: + // + // Input (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + // Flag (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0] + // + // Afterwards, the (index) components of the ubpmem buffer contain the lengths + // of the segments (minus 1), i.e. the counts of each element in the original + // input. + inclusivePrefixScan( + ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) { + ModeUnsignedBoolPair c; + c.val = a.flag ? a.val : a.val + b.val; + c.flag = a.flag | b.flag; + return c; + }); + // assumes scan syncs at the end + + // Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e. + // we treat the boolean flag regions as integers). We initialize these to + // represent indices, and we'll call this buffer I + struct ModeUnsignedPair* uupmem = + reinterpret_cast(ubpmem); + + // At this point, we need to find the maximum element in lengths buffer C. + // This element will represent the count (-1) of the mode. Because of the + // way we have set up the problem, the index where this mode occurs will + // also be the location of the mode value in the sorted array, e.g. + // + // smem = [0, 0, 1, 1, 1, 2] + // C = [0, 1, 0, 1, 2, 0] + // I = [0, 1, 2, 3, 4, 5] + // ^ + // maximum value, also aligned with mode = 1 + // + // We perform a block wide max-reduction of the C buffer, but we also need the + // indices to come along with it, so we utilize the uupmem construction. + // + // At the end we need to return the ModeUnsignedPair containing index = 4, val + // = 2, which represents the max + + // In practice, we will make each thread locally reduce 2 values in its + // registers prior to the global block-wide reduction. Note that instead of + // tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with + // adjacent elements. This is because the reduce code below relies on thread + // elements to be adjacent. + struct ModeUnsignedPair uup[2]; + uup[0].index = tidx * 2; + uup[0].val = ubpmem[tidx * 2].val; + uup[1].index = tidx * 2 + 1; + uup[1].val = ubpmem[tidx * 2 + 1].val; + __syncthreads(); + + struct ModeUnsignedPair max = {0, 0}; + + struct MaxOp { + inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const { + return b.val > a.val ? 
b : a; + } + + inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const { + ModeUnsignedPair ret; + ret.index = WARP_SHFL_DOWN(acc.index, offset); + ret.val = WARP_SHFL_DOWN(acc.val, offset); + return ret; + } + } max_op; + + max = reduceBlockWithNThreadLocalReductions<2>( + uupmem, + uup, + sliceSize, + max_op, + max); + + // Store the mode in shared memory for use in finding the mode in the input + // slice + __shared__ T mode; + + // Given the above constraints, the mode is the value at the reduced index in + // the original sorted element buffer + if (tidx == 0) { + mode = smem[max.index]; + } + __syncthreads(); // broadcast mode + + // Finally, we need to find "an" index of the mode in the input + // Tensor. The API does not constrain which index we pick, but here + // we always pick the largest index. We store the index if the value + // is the mode, or 0 otherwise. Then find the maximum value. + // + // Again we reduce 2 elements in the thread's registers prior to the + // block-wide reduction + unsigned mode_index[2] = {0u, 0u}; + if (tidx * 2 < sliceSize) { + const unsigned idx = tidx * 2; + mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; + } + if (tidx * 2 + 1 < sliceSize) { + const unsigned idx = tidx * 2 + 1; + mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; + } + + struct MaxIndexOp { + inline __device__ unsigned combine(unsigned a, unsigned b) const { + return b > a ? b : a; + } + + inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } + } max_index_op; + + int64_t index = reduceBlockWithNThreadLocalReductions<2>( + reinterpret_cast(&shmem[0]), + mode_index, + sliceSize, + max_index_op, + 0u); + + // Finally, we have the mode, and an index where it occurs. 
We use a single + // thread to place this in the appropriate output position + if (tidx == 0) { + unsigned int outputOffset = + at::cuda::detail::IndexToOffset::get( + blockId, values); + values.data[outputOffset] = mode; + indices.data[outputOffset] = index; + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3096850f793d3c3dd07033c3e45fc812e37e0c83 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_fused_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t slice_size, int64_t slices); + +void launch_apply_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t ndim); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorTopK.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorTopK.h new file mode 100644 index 0000000000000000000000000000000000000000..6cf08cd73642e8ba12bf823df65a753610f70400 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/TensorTopK.h @@ -0,0 +1,13 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { +void launch_gather_topk_kernel( + const TensorBase& self, + int64_t k, int64_t dim, bool largest, + const TensorBase& values, const TensorBase& indices); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3256fb076d48ca68c00151845fa1a80d879ecd4c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh @@ -0,0 +1,12 @@ +#include + +namespace at::native::internal { + +template +std::tuple unique_cuda_template( + const Tensor& self, + const bool consecutive, + const bool return_inverse, + const bool return_counts); + +} // namespace at::native::internal diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/UpSample.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/UpSample.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9e5716dded2194dd544ca9c6f276d0264bc97c37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/UpSample.cuh @@ -0,0 +1,368 @@ +#pragma once +#include +#include + +#include +#include +#include + +#include +#include + +namespace at::native { + +namespace upsample { +// TODO: Remove duplicate declaration. +TORCH_API c10::SmallVector compute_output_size( + c10::IntArrayRef input_size, // Full input tensor size. + at::OptionalIntArrayRef output_size, + std::optional> scale_factors); +} // namespace upsample + +namespace upsample_cuda { + +// TODO: Remove duplication with Upsample.h (CPU). +inline std::optional get_scale_value(std::optional> scales, int idx) { + if (!scales) { + return std::nullopt; + } + return scales->at(idx); +} + +} // namespace upsample_cuda + + +/* TODO: move this to a common place */ +template +__device__ inline scalar_t min(scalar_t a, scalar_t b) { + return a < b ? 
a : b; +} + +template +__device__ inline scalar_t max(scalar_t a, scalar_t b) { + return a > b ? a : b; +} + +// NOTE [ Nearest neighbor upsampling kernel implementation ] +// +// The nearest neighbor upsampling kernel implementation is symmetrical as +// expected. We launch kernels with threads mapping to destination tensors where +// kernels write data to, each thread reads data from the source tensor, this +// means: +// 1. In the forward kernel, +// src_xxx refers to properties of input tensors; +// dst_xxx refers to properties of output tensors; +// scale_factor is the ratio of src_size to dst_size; +// 2. In the backward kernel, +// src_xxx refers to properties of grad_output tensors; +// dst_xxx refers to properties of grad_input tensors; +// scale_factor is the ratio of src_size to dst_size; +// +// Because of this, we need to take the reciprocal of the scale defined by +// upsample layer during forward path. The motivation is to avoid slow +// division in the kernel code, so we can use faster multiplication instead. +// This is not necessary during backward path, since the scale_factor is already +// the reciprocal of corresponding scale_factor used in the forward path due to +// the swap of source and destination tensor. +// +// Similarly, since the mapping from grad_input to grad_output during backward +// is the reverse of the mapping of output to input, we need to have opposite +// mapping functions to compute the source index. + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +template +__host__ __forceinline__ accscalar_t compute_scales_value( + const std::optional scale, + int64_t src_size, + int64_t dst_size) { + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)(1.0 / scale.value()) + : (accscalar_t)src_size / dst_size; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +template +__host__ __forceinline__ accscalar_t compute_scales_value_backwards( + const std::optional scale, + int64_t src_size, + int64_t dst_size) { + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)scale.value() + : (accscalar_t)src_size / dst_size; +} + +template +__host__ __forceinline__ accscalar_t area_pixel_compute_scale( + int input_size, + int output_size, + bool align_corners, + const std::optional scale) { + if(align_corners) { + if(output_size > 1) { + return (accscalar_t)(input_size - 1) / (output_size - 1); + } + else { + return static_cast(0); + } + } + else{ + return compute_scales_value(scale, input_size, output_size); + } +} + +template +__device__ __forceinline__ accscalar_t area_pixel_compute_source_index( + accscalar_t scale, + int dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + accscalar_t src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + // See Note[Follow Opencv resize logic] + return (!cubic && src_idx < static_cast(0)) + ? static_cast(0) + : src_idx; + } +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_compute_source_index( + const float scale, + int dst_index, + int input_size) { + // index_f32 = (output_index) * scale + // input_index = round(index_f32) + // Same as a buggy OpenCV INTER_NEAREST + // We keep this method for BC and consider as deprecated. 
+ // See nearest_neighbor_exact_compute_source_index as replacement + const int src_index = + min(static_cast(floorf((dst_index) * scale)), input_size - 1); + return src_index; +} + +__device__ __forceinline__ int nearest_neighbor_exact_compute_source_index( + const float scale, + int dst_index, + int input_size) { + // index_f32 = (output_index + 0.5) * scale - 0.5 + // input_index = round(index_f32) + // Same as Pillow and Scikit-Image/Scipy ndi.zoom + const int src_index = + min(static_cast(floorf((dst_index + static_cast(0.5)) * scale)), input_size - 1); + return src_index; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_bw_compute_source_index( + const float scale, + int dst_index, + int output_size) { + // Equivalent to buggy OpenCV INTER_NEAREST + // We keep this method for BC and consider as deprecated. + // See nearest_neighbor_exact_bw_compute_source_index as replacement + const int src_index = + min(static_cast(ceilf(dst_index * scale)), output_size); + return src_index; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_exact_bw_compute_source_index( + const float scale, + int dst_index, + int output_size) { + // Equivalent to Pillow and Scikit-Image/Scipy ndi.zoom + const int src_index = + min(static_cast(ceilf(dst_index * scale - static_cast(0.5))), output_size); + return src_index; +} + +/* Used by UpSampleBicubic2d.cu */ +template +__device__ __forceinline__ scalar_t upsample_get_value_bounded( + const PackedTensorAccessor64& data, + int batch, + int channel, + int height, + int width, + int y, + int x) { + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + return data[batch][channel][access_y][access_x]; +} + +/* Used by UpSampleBicubic2d.cu */ +template +__device__ __forceinline__ void upsample_increment_value_bounded( + PackedTensorAccessor64& data, + int batch, + int channel, + int height, + int width, + int y, + int x, + accscalar_t value) { + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + /* TODO: result here is truncated to scalar_t, + check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912 + */ + gpuAtomicAddNoReturn( + &data[batch][channel][access_y][access_x], static_cast(value)); +} + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +__device__ __forceinline__ accscalar_t cubic_convolution1( + accscalar_t x, + accscalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +__device__ __forceinline__ accscalar_t cubic_convolution2( + accscalar_t x, + accscalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +__device__ __forceinline__ void get_cubic_upsampling_coefficients( + accscalar_t coeffs[4], + accscalar_t t) { + accscalar_t A = -0.75; + + accscalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + accscalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +__device__ __forceinline__ accscalar_t cubic_interp1d( + scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + accscalar_t t) { + accscalar_t coeffs[4]; + get_cubic_upsampling_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +namespace upsample_antialias { + 
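[Editor's note] Just before this namespace, get_cubic_upsampling_coefficients and cubic_interp1d define the A = -0.75 cubic convolution weights. The following host-side sketch (hypothetical file; cc1/cc2/coeffs/interp are local stand-ins for the device helpers) evaluates them, checks that the four weights always sum to 1, and checks that t = 0 and t = 1 reproduce the two middle samples exactly.

// cubic_interp_demo.cpp
#include <cassert>
#include <cmath>
#include <cstdio>

double cc1(double x, double A) { return ((A + 2) * x - (A + 3)) * x * x + 1; }
double cc2(double x, double A) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; }

void coeffs(double c[4], double t, double A = -0.75) {
  c[0] = cc2(t + 1.0, A);   // weight of the sample one step left of the interval
  c[1] = cc1(t, A);
  c[2] = cc1(1.0 - t, A);
  c[3] = cc2(2.0 - t, A);   // weight of the sample one step right of the interval
}

double interp(double x0, double x1, double x2, double x3, double t) {
  double c[4];
  coeffs(c, t);
  return x0 * c[0] + x1 * c[1] + x2 * c[2] + x3 * c[3];
}

int main() {
  for (double t = 0.0; t <= 1.0; t += 0.125) {
    double c[4];
    coeffs(c, t);
    assert(std::fabs(c[0] + c[1] + c[2] + c[3] - 1.0) < 1e-12);  // weights sum to 1
  }
  assert(std::fabs(interp(10, 20, 30, 40, 0.0) - 20.0) < 1e-12);  // t = 0 -> x1
  assert(std::fabs(interp(10, 20, 30, 40, 1.0) - 30.0) < 1e-12);  // t = 1 -> x2
  std::printf("midpoint between 20 and 30: %f\n", interp(10, 20, 30, 40, 0.5));
}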
+// taken from +// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L20-L29 +struct BilinearFilterFunctor { + + template + __device__ accscalar_t operator()(accscalar_t x) const { + if (x < 0) { + x = -x; + } + if (x < 1) { + return 1 - x; + } + return 0; + } + + static const int size = 2; +}; + +// taken from +// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L46-L62 +struct BicubicFilterFunctor { + + template + __device__ accscalar_t operator()(accscalar_t x) const { + // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + const accscalar_t a = -0.5; + if (x < 0) { + x = -x; + } + if (x < 1) { + return ((a + 2) * x - (a + 3)) * x * x + 1; + } + if (x < 2) { + return (((x - 5) * x + 8) * x - 4) * a; + } + return 0; + } + + static const int size = 4; +}; + +template +__device__ __forceinline__ void _compute_weights_span( + const int i, + const int input_size, + const accscalar_t scale, + const accscalar_t support, + int& xmin, + int& xsize, + accscalar_t& center) { + center = scale * (i + static_cast(0.5)); + xmin = max(static_cast(center - support + static_cast(0.5)), static_cast(0)); + xsize = min(static_cast(center + support + static_cast(0.5)), input_size) - xmin; +} + +template +__device__ __forceinline__ void _compute_weights( + scalar_t* wt_ptr, + const accscalar_t scale, + int interp_size, + const interp_filter_t& interp_filter, + accscalar_t xmin_m_center, + int xsize) { + + accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0; + accscalar_t total_w = 0.0; + int j = 0; + for (j = 0; j < xsize; j++) { + accscalar_t w = interp_filter((j + xmin_m_center + static_cast(0.5)) * invscale); + wt_ptr[j] = static_cast(w); + total_w += w; + } + for (j = 0; j < xsize; j++) { + if (total_w != 0.0) { + wt_ptr[j] /= total_w; + } + } + for (; j < interp_size; j++) { + wt_ptr[j] = static_cast(0.0); + } +} + +template +__device__ __forceinline__ accscalar_t interpolate_aa_single_dim( + const scalar_t* src, + const scalar_t* weights, + int size) { + scalar_t t = static_cast(*src); + scalar_t wts = static_cast(weights[0]); + accscalar_t output = t * wts; + + int j = 1; + for (; j < size; j++) { + wts = static_cast(weights[j]); + t = static_cast(*(src + j)); + output += t * wts; + } + return output; +} + +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8311c6616e3a914ae3234330719f8d053a37078f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh @@ -0,0 +1,139 @@ +#pragma once + +#include + +#include +#include + +namespace at::native::cuda_utils { + +constexpr int kCUDABlockReduceNumThreads = 512; +// Algorithmic limitation: BlockReduce does two WarpReduce calls, each +// of which reduces C10_WARP_SIZE elements. So, at most +// C10_WARP_SIZE**2 elements can be reduced at a time. +// NOTE: This is >= the max block size on current hardware anyway (1024). +constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; + +// Sums `val` across all threads in a warp. 
+// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template <typename T> +__inline__ __device__ T WarpReduceSum(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN(val, offset); + } + return val; +} + +// Picks the maximum `val` across all threads in a warp. +// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template <typename T> +__inline__ __device__ T WarpReduceMax(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset)); + } + return val; +} + +struct Block1D { + static __forceinline__ __device__ int Tid() { return threadIdx.x; } + + static __forceinline__ __device__ int Warps() { + return blockDim.x / C10_WARP_SIZE; + } +}; + +struct Block2D { + static __forceinline__ __device__ int Tid() { + return threadIdx.x + threadIdx.y * blockDim.x; + } + + static __forceinline__ __device__ int Warps() { + return blockDim.x * blockDim.y / C10_WARP_SIZE; + } +}; + +// Sums `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template <typename T, typename B = Block1D> +__inline__ __device__ T BlockReduceSum(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceSum(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : T(0); + if (wid == 0) { + val = WarpReduceSum(val); + } + return val; +} + +// Picks out the maximum `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template <typename T, typename B = Block1D> +__inline__ __device__ T BlockReduceMax(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceMax(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits<T>::lowest()); + if (wid == 0) { + val = WarpReduceMax(val); + } + return val; +} + +template <typename T, class ReduceOp> +__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = op.combine(val, op.warp_shfl_down(val, offset)); + } + return val; +} + +template <typename T, class ReduceOp, typename B = Block1D> +__inline__ __device__ T +BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduce(val, op); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ?
shared[lid] : identity_element; + if (wid == 0) { + val = WarpReduce(val, op); + } + return val; +} + +} // namespace at::native::cuda_utils diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh new file mode 100644 index 0000000000000000000000000000000000000000..38ec667d726c513e3e0940f8ac72f43ab97b8905 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace at::cuda::detail { + +template +struct enable_2x_kernel_for_sm89 : Kernel { + template + CUTLASS_DEVICE static void invoke(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 890 + Kernel::invoke(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm9x : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm10_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +} // namespace at::cuda::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..30e2b2e4cb9eb94d183d24812f6d2e4fbb2faeac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::native { + +void _fused_adam_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c2bad59598c17aa9c2241e588400fbf50fb4d41a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace at::native { + +void _fused_adam_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_cuda_impl_( + 
at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ccb911aee5c56a1e3d4c7df42d9123be542e7693 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh @@ -0,0 +1,200 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at::native { + +enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; + +namespace { + +constexpr uint8_t kParamIdx = 0; +constexpr uint8_t kGradIdx = 1; +constexpr uint8_t kExpAvgIdx = 2; +constexpr uint8_t kExpAvgSqIdx = 3; +constexpr uint8_t kMaxExpAvgSqIdx = 4; + +template < + typename scalar_type, + typename opmath_t, + int depth, + ADAM_MODE adam_mode, + bool amsgrad> +C10_DEVICE inline void adam_math( + scalar_type r_args[depth][kILP], + const double& lr, + const double& beta1, + const double& beta2, + const double& weight_decay, + const double& eps, + const bool& maximize, + const float* grad_scale_ptr, + const float* found_inf_ptr, + const opmath_t& bias_correction1, + const opmath_t& bias_correction2_sqrt) { + static_assert(depth == 4 || depth == 5); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + // Load values. + opmath_t param = static_cast(r_args[kParamIdx][ii]); + opmath_t grad = static_cast(r_args[kGradIdx][ii]); + if (grad_scale_ptr) { + grad /= (static_cast(*grad_scale_ptr)); + } + const opmath_t grad_to_store = grad; + if (maximize) { + grad = -grad; + } + opmath_t exp_avg = static_cast(r_args[kExpAvgIdx][ii]); + opmath_t exp_avg_sq = static_cast(r_args[kExpAvgSqIdx][ii]); + opmath_t max_exp_avg_sq; + if (amsgrad) { + max_exp_avg_sq = static_cast(r_args[kMaxExpAvgSqIdx][ii]); + } + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + if constexpr (adam_mode == ADAM_MODE::ORIGINAL) { + grad += param * weight_decay; + } else if constexpr (adam_mode == ADAM_MODE::ADAMW) { + param -= lr * weight_decay * param; + } + } + // todo(crcrpar): use lerp + // ref: https://developer.nvidia.com/blog/lerp-faster-cuda/ + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const opmath_t step_size = lr / bias_correction1; + opmath_t denom; + if (amsgrad) { + max_exp_avg_sq = std::max(max_exp_avg_sq, exp_avg_sq); + denom = (std::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps; + } else { + denom = (std::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps; + } + param -= step_size * exp_avg / denom; + + // Store results. + r_args[kParamIdx][ii] = param; + if (grad_scale_ptr) { + r_args[kGradIdx][ii] = grad_to_store; + } + r_args[kExpAvgIdx][ii] = exp_avg; + r_args[kExpAvgSqIdx][ii] = exp_avg_sq; + if (amsgrad) { + r_args[kMaxExpAvgSqIdx][ii] = max_exp_avg_sq; + } + } +} + +// [note: Conditional Gradient Store when `optimizer.step` is called by +// GradScaler] When a user is training their model(s) with an FP16 AMP recipe, +// parameter updates are done via `grad_scaler.step(optimizer)` instead of +// `optimizer.step()`. 
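+// (For reference, the usual AMP sequence is scaler.scale(loss).backward(),
+// then scaler.step(optimizer), then scaler.update(); scaler.step is the call
+// that eventually dispatches to the fused kernel below.)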
For most optimizers, GradScaler unscales gradients on +// behalf of those optimizers. Also, before `.step`, it makes sure that all the +// gradients involved are finite, which incurs a device sync. On the other hand, +// fused optimizers set their member variable of `_step_supports_amp_scaling` to +// `True` in order to remove the device sync above. This means that fused +// optimizers have to have their CUDA kernels (a) unscale gradients and (b) skip +// parameter updates accordingly. To be functionally on par with `torch.optim` +// optimizers and `_multi_tensor` ones, the kernel below writes out gradients +// only when `grad_scale_ptr != nullptr. +template +struct FusedAdamMathFunctor { + static_assert( + depth == 4 || depth == 5, + "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); + using opmath_t = at::opmath_type; + C10_DEVICE __forceinline__ void operator()( + int chunk_size, + FusedOptimizerTensorListMetadata& tl, + const float* lr_ptr, + const double& lr, + const double& beta1, + const double& beta2, + const double& weight_decay, + const double& eps, + const bool& maximize, + const float* grad_scale_ptr, + const float* found_inf_ptr) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + const double lr_double = lr_ptr ? *lr_ptr : lr; + + if (found_inf_ptr && *found_inf_ptr == 1) { + return; + } + const auto [bias_correction1, bias_correction2_sqrt] = + [&]() -> std::pair { + auto* step_count = + reinterpret_cast(tl.state_steps_addresses[tensor_loc]); + const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count); + const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count); + const auto bias_correction2_sqrt = std::sqrt(bias_correction2); + return {bias_correction1, bias_correction2_sqrt}; + }(); + + scalar_type* args[depth]; + scalar_type r_args[depth][kILP]; + const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size; + + const bool all_aligned{ + init_args(args, tl, chunk_idx, chunk_size, tensor_loc)}; + if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (int i = 0; i < depth; i++) { + load_store(r_args[i], args[i], 0, i_start); + } + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); +#pragma unroll + for (int i = 0; i < depth; i++) { + if (i != kGradIdx || grad_scale_ptr) { + load_store(args[i], r_args[i], i_start, 0); + } + } + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); +#pragma unroll + for (int i = 0; i < depth; i++) { + if (i != kGradIdx || grad_scale_ptr) { + store_args(args[i], r_args[i], i_start, chunk_size, n); + } + } + } + } + } +}; +} // namespace + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7a7b0b17407ce6fa38ce85054378bcc8ac099c89 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::native { + +void _fused_adamw_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c2891286b516637b54e7848503409fcceb80c554 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace at::native { + +void _fused_adamw_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/im2col.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/im2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cca68e5ee2ff5e918b6ad6172c214c716b3f0197 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/im2col.cuh @@ -0,0 +1,336 @@ +#pragma once + +#include +#include +#include + +#include + +namespace at::native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy +// (borrowed from Caffe: +// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu) +// CUDA_NUM_THREADS = 1024 + +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void im2col_kernel( + const int64_t n, + const dt* data_im, + const int64_t height, + const int64_t width, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_col) { + CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) { + int64_t w_out = index % width_col; + + int64_t idx = index / width_col; + + int64_t h_out = idx % height_col; + int64_t channel_in 
= idx / height_col; + int64_t channel_out = channel_in * kernel_height * kernel_width; + int64_t h_in = h_out * stride_height - pad_height; + int64_t w_in = w_out * stride_width - pad_width; + + dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out; + const dt* im = data_im + (channel_in * height + h_in) * width + w_in; + + for (int64_t i = 0; i < kernel_height; ++i) { + for (int64_t j = 0; j < kernel_width; ++j) { + int64_t h = h_in + i * dilation_height; + int64_t w = w_in + j * dilation_width; + *col = (h >= 0 && w >= 0 && h < height && w < width) + ? im[i * dilation_height * width + j * dilation_width] + : static_cast
(0); + col += height_col * width_col; + } + } + } +} + +template +void im2col( + cudaStream_t stream, + const dt* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_col) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int64_t num_kernels = channels * height_col * width_col; + // Launch CUDA_NUM_THREADS = 1024 + im2col_kernel<<>>( + num_kernels, + data_im, + height, + width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__forceinline__ __device__ void col2im_device( + const int64_t index, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + accT val = static_cast(0); + const int64_t w_im = index % width + pad_width; + const int64_t h_im = (index / width) % height + pad_height; + const int64_t c_im = index / (width * height); + int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1; + int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1; + // compute the start and end of the output + const int64_t w_col_start = (w_im < kernel_extent_w) + ? 0 + : (w_im - kernel_extent_w) / stride_width + 1; + const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col); + const int64_t h_col_start = (h_im < kernel_extent_h) + ? 0 + : (h_im - kernel_extent_h) / stride_height + 1; + const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col); + + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int64_t h_k = (h_im - h_col * stride_height); + int64_t w_k = (w_im - w_col * stride_width); + if (h_k % dilation_height == 0 && w_k % dilation_width == 0) { + h_k /= dilation_height; + w_k /= dilation_width; + int64_t data_col_index = + (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col + + h_col) * + width_col + + w_col; + val += data_col[data_col_index]; + } + } + } + data_im[index] = static_cast
(val); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_kernel( + const int64_t n, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + CUDA_KERNEL_LOOP(index, n) { + col2im_device( + index, + data_col, + height, + width, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + } +} + +template +void col2im( + cudaStream_t stream, + const dt* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im) { + int64_t num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_kernel + <<>>( + num_kernels, + data_col, + height, + width, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_batched_kernel( + const int64_t n, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im, + const int64_t im_batch_stride) { + using accT = at::acc_type; + const auto im_numel = n * nbatch; + + CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) { + const auto ibatch = index / n; + const auto slice_index = index % n; + + col2im_device( + slice_index, + data_col + ibatch * col_batch_stride, + height, + width, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im + ibatch * im_batch_stride); + } +} + +template +void col2im_batched( + cudaStream_t stream, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im, + const int64_t im_batch_stride) { + const int64_t num_kernels = channels * height * width; + const int64_t output_numel = nbatch * num_kernels; + if (output_numel == 0) { + return; // No work to do + } + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then 
in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_batched_kernel<<>>( + num_kernels, + data_col, + col_batch_stride, + nbatch, + height, + width, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im, + im_batch_stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/jit_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/jit_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b4a3e370f5db732f87b7cb39d1af12cb7b4d4978 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/jit_utils.h @@ -0,0 +1,249 @@ +#pragma once + +#include + +#include +#include +#include + +namespace at::cuda::jit { + +enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar}; + +struct NvrtcFunction { + CUmodule module = CUmodule(); + CUfunction function = nullptr; +}; + +struct KernelDescriptor { + std::string name; + std::string f; + c10::ScalarType f_inputs_type; + c10::ScalarType result_type; + c10::SmallVector extra_args_types; + int nInputs, nOutputs; +}; + +// Helper function to return a vector +// corresponding to the type of the arguments in parameter pack. +template +c10::SmallVector get_extra_args_types() { + return {c10::CppTypeToScalarType::value ...}; +} + +template < + typename result_type, + typename f_inputs_type, + typename... ExtraArgs> +KernelDescriptor make_kernel_descriptor( + std::string name, + std::string f, + int nInputs, + int nOutputs) { + KernelDescriptor ret; + ret.name = std::move(name); + ret.f = std::move(f); + ret.f_inputs_type = c10::CppTypeToScalarType::value; + ret.result_type = c10::CppTypeToScalarType::value; + ret.extra_args_types = get_extra_args_types(); + ret.nInputs = nInputs; + ret.nOutputs = nOutputs; + return ret; +} + +inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { + auto ip = reinterpret_cast(pointer); +#ifdef USE_ROCM + if ((default_alignment == 1) && (ip % (16 * default_alignment) == 0)) { + return 16; + } + if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) { + return 8; + } +#else + if (ip % (8 * default_alignment) == 0) { + return 8; + } +#endif + if (ip % (4 * default_alignment) == 0) { + return 4; + } + if (ip % (2 * default_alignment) == 0) { + return 2; + } + return 1; +} + +inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef pointers) { + TORCH_INTERNAL_ASSERT(desc.nOutputs == 1); + TORCH_INTERNAL_ASSERT(static_cast(pointers.size()) == 1 + desc.nInputs); + + // Deals with output + auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize(); + auto result = can_vectorize_up_to(result_size, pointers[0]); + + // Incorporates input(s) + auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + for (auto i : c10::irange(1, pointers.size())) { + result = std::min(result, can_vectorize_up_to(input_size, pointers[i])); + } + + return result; +} + +//FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh +#ifdef USE_ROCM +#define JIT_THREAD_WORK_SIZE 4 +#else +#define JIT_THREAD_WORK_SIZE 8 +#endif + +int calc_io_size( + const int nInputs, + const int nOutputs, + const c10::ScalarType& inputs_type, + const c10::ScalarType& result_type); + +int calc_thread_work_size( + const int 
nInputs, + const int nOutputs, + const c10::ScalarType& inputs_type, + const c10::ScalarType& result_type); + +std::string generate_code( + int nInputs, + int nOutputs, + const std::string& func, + const std::string& name, + const std::string& f_inputs_type, + const std::string& compute_type, + const std::string& result_type, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + c10::SmallVector& extra_args_typenames, + int thread_work_size=JIT_THREAD_WORK_SIZE, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + int thread_work_size=JIT_THREAD_WORK_SIZE, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_reduction_code( + int nOutputs, + const std::string& func, + const std::string& name, + const int vt0, + const std::string& f_inputs_type, + const std::string& reduction_accum_type, + const std::string& result_type, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +std::string generate_reduction_code( + const KernelDescriptor &desc, + const int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +NvrtcFunction jit_pwise_function( + const std::string& code, + const std::string& kernel_name); + +void launch_jitted_pwise_function( + NvrtcFunction function, + const void* args[], + const dim3 nBlocks, + const dim3 kBlockSize, + const int smem=0); + +template +struct delayed_false : std::false_type { +}; + +// Defines type names +// NOTE: General case is instantiated only for invalid types. +// All the valid types have specialization using the TYPE_NAME_FN +// macro below. +template +inline std::string typeName() { + // we can't use static_assert(false) directly as the + // program will be not compiled even if the template is not + // instantiated, so we use `delayed_false` + // to make sure compiler doesn't eagerly raise + // fail this assertion. + static_assert(delayed_false::value, "invalid type for jiterator"); + return "void"; +} + +#define TYPE_NAME_FN(ctype, name) \ +template <> inline std::string typeName(){ \ + return std::string(#ctype); \ +} + +AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN) +#undef TYPE_NAME_FN +// JIT uses std::complex directly, because nvRTC compile programs +// with -default-device, so there is no such issue like: +// "std::sin(complex) is __host__ only" +template <> inline std::string typeName(){ + return "bool"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName(){ + return "at::Half"; +} +template <> inline std::string typeName(){ + return "at::BFloat16"; +} +template <> inline std::string typeName(){ + return "at::Float8_e5m2"; +} +template <> inline std::string typeName(){ + return "at::Float8_e4m3fn"; +} +template <> inline std::string typeName() { + return "at::Float8_e5m2fnuz"; +} +template <> inline std::string typeName() { + return "at::Float8_e4m3fnuz"; +} +template <> inline std::string typeName() { + // TODO(#146647): Can the code here be made generic for any scalartype? 
+ return "at::Float8_e8m0fnu"; +} + +#define TYPE_NAME_CASE(ctype, scalartype) \ + case ScalarType::scalartype: return typeName(); +inline std::string typeName(ScalarType t) { + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE) + default: + TORCH_CHECK(false, "invalid type for jiterator"); + } +} +#undef TYPE_NAME_CASE + +TORCH_CUDA_CPP_API void initializeCudaContext(); + +} // namespace at::cuda::jit diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh new file mode 100644 index 0000000000000000000000000000000000000000..dee0c35e258b8d45da34b788f46d602cc1998432 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh @@ -0,0 +1,680 @@ +namespace at::cuda { +//windows doesn't like large string literals, so split in two +const std::string reduction_template_0 = R"ESCAPE( + #define C10_HOST_DEVICE __host__ __device__ + #define C10_DEVICE __device__ + #if defined(__clang__) && defined(__HIP__) + #ifndef __forceinline__ + #define __forceinline__ inline __attribute__((always_inline)) + #endif + // until ROCm support for kernel asserts is restored + #define assert(expr) (static_cast(0)) + #endif + + template + __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + #if defined(__clang__) && defined(__HIP__) + return __shfl_down(value, delta, width); + #else + return __shfl_down_sync(mask, value, delta, width); + #endif + } + + + #if ${complex} + template + __device__ __forceinline__ std::complex WARP_SHFL_DOWN(std::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + return std::complex( + #if defined(__clang__) && defined(__HIP__) + __shfl_down(value.real(), delta, width), + __shfl_down(value.imag(), delta, width)); + #else + __shfl_down_sync(mask, value.real(), delta, width), + __shfl_down_sync(mask, value.imag(), delta, width)); + #endif + } + #endif + + // aligned vector generates vectorized load/store on CUDA + template + struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; + }; + + + C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. 
+ size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; + } + + + + + struct ReduceConfig { + //has to match host-side ReduceConfig in the eager code + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + static constexpr int input_vec_size = 4; + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx < num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + offset = threadIdx.x + offset * blockDim.x; + } + return offset; + } + + + }; + + +//TODO this will need to be different for more generic reduction functions +namespace reducer { + + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + + inline __device__ ${functor} + + inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + + inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + // wrap a normal reduction that ignores the index + inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) { + return combine(acc, val); + } +} + + +struct ReduceJitOp { + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + using InputCalculator = OffsetCalculator<1>; + using OutputCalculator = OffsetCalculator<2>; + +// static constexpr bool can_accumulate_in_output = +// std::is_convertible_v +// && std::is_convertible_v; + + static constexpr int input_vec_size = ReduceConfig::input_vec_size; + + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for 
accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + uint32_t output_idx = config.output_idx<${output_vec_size}>(); + uint32_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = Array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + + value = thread_reduce<${output_vec_size}>(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce<${output_vec_size}>(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce<${output_vec_size}>(value, shared_memory); + } + + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = global_reduce<${output_vec_size}>(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output<${output_vec_size}>(out, value); + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE Array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + assert(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. 
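+ // (Roughly: if `data` starts `shift` scalar elements past an alignment
+ // boundary, the few leading elements before the next aligned address are
+ // reduced one by one by the threads whose indices match them, and the main
+ // loop then reads aligned vectors of input_vec_size elements.)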
+ return {input_vectorized_thread_reduce_impl(data)}; + } else { + uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](uint32_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](uint32_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + uint32_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = aligned_vector; + + uint32_t idx = config.input_idx(); + const uint32_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. + arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + scalar_t values[input_vec_size]; + + load_t *values_vector = reinterpret_cast(&values[0]); + + while (idx * input_vec_size + input_vec_size - 1 < end) { + *values_vector = reinterpret_cast(data)[idx]; + #pragma unroll + for (uint32_t i = 0; i < input_vec_size; i++) { + value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + uint32_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = reducer::combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE Array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + uint32_t idx = config.input_idx(); + const uint32_t end = config.num_inputs; + const uint32_t stride = config.step_input; + const int vt0=${vt0}; + + using arg_vec_t = Array; + using load_t = aligned_vector; + const load_t* data = reinterpret_cast(data_); + + // Multiple accumulators to remove dependency between unrolled loops. 
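+ // (Each of the vt0 accumulators carries its own dependency chain, so loads
+ // and reduce() calls from different unrolled iterations can overlap instead
+ // of serializing on a single running value; the chains are combined once at
+ // the end.)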
+ arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + values[i] = data[calc(idx + i * stride) / output_vec_size]; + } + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + values[i] = data[calc(idx) / output_vec_size]; + idx += stride; + } + idx = idx_; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + template + C10_DEVICE Array block_x_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[address_base] = value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + for (int offset = 1; offset < dim_x; offset <<= 1) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + arg_t other = reducer::warp_shfl_down(value[i], offset); + value[i] = reducer::combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE Array block_y_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + )ESCAPE"; + + const std::string reduction_template_1 = R"ESCAPE( + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE Array accumulate_in_output( + Array out, + Array value + ) const { + Array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + 
ret[i] = reducer::combine(*(out[i]), value[i]); + } + return ret; + } + + + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value + ) const { + assert(!final_output); + return (out_scalar_t)value; + } + + template + C10_DEVICE void set_results(const T x, const uint32_t base_offset) const { + assert(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + +//TODO - multi-output reduction - we won't be able to use thrust::pair +//just explicitly specify typed output reads/writes +//Currently implemented for max of two outputs +// template +// C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { +// if (noutputs >= 1) { +// auto res0 = (T1*)((char*)dst[0] + base_offset); +// *res0 = x.first; +// } +// if (noutputs >= 2) { +// // base offset is computed assuming element size being sizeof(T1), so we need to make a +// // correction to obtain the correct base offset +// auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); +// *res1 = x.second; +// } +// } + + template + C10_DEVICE void set_results_to_output(Array value, Array base_offset) const { + assert(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(reducer::project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE Array global_reduce(Array value, Array *acc, char* shared_memory) const { + using arg_vec_t = Array; + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + uint32_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + uint32_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + __threadfence(); //complete acquire pattern + value = ident; + if (config.should_block_x_reduce()) { + uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + uint32_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } else { + uint32_t input_offset = threadIdx.y; + uint32_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + 
} + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +extern "C" +__launch_bounds__(${max_threads_lb}, 4) +__global__ void reduction_${name}_kernel(ReduceJitOp r){ + r.run(); +} +)ESCAPE"; + +const std::string reduction_template = reduction_template_0 + reduction_template_1; + + +const std::string &get_reduction_template() { + return reduction_template; +} + +} // namespace at::cuda diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/thread_constants.h b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/thread_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..78c37782eb6be7feb0b0894bbf3366e0b8bc759e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/thread_constants.h @@ -0,0 +1,25 @@ +#pragma once +#include + +// Marks a lambda as executable on both the host and device. The __host__ +// attribute is important so that we can access static type information from +// the host, even if the function is typically only executed on the device. +#ifndef GPU_LAMBDA +#define GPU_LAMBDA __host__ __device__ +#endif + +#if defined(USE_ROCM) +constexpr int num_threads() { + return 256; +} + +constexpr int thread_work_size() { return 4; } +#else +constexpr uint32_t num_threads() { + return C10_WARP_SIZE * 4; +} + +constexpr int thread_work_size() { return 8; } +#endif + +constexpr int block_work_size() { return thread_work_size() * num_threads(); } diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/vol2col.cuh b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/vol2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1e743de363c762c48243e63d45b5bc4696bd1ffd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/cuda/vol2col.cuh @@ -0,0 +1,262 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at::native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy on volumes +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void vol2col_kernel( + const int64_t n, + const T* data_vol, + const int depth, + const int height, + const int width, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + const int depth_col, + const int height_col, + const int width_col, + T* data_col) { + CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) { + auto w_out = index % width_col; + index /= width_col; + auto h_out = index % height_col; + index /= height_col; + auto t_out = index % depth_col; + auto channel_in = index / depth_col; + auto channel_out = channel_in * ksize_t * ksize_h * ksize_w; + auto t_in = t_out * stride_t - pad_t; + auto h_in = h_out * stride_h - pad_h; + auto w_in = w_out * stride_w - pad_w; + data_col += + ((channel_out * depth_col + t_out) * height_col + h_out) * width_col + + w_out; + data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in; + for (int i = 0; i < 
ksize_t; ++i) { + for (int j = 0; j < ksize_h; ++j) { + for (int k = 0; k < ksize_w; ++k) { + auto t = t_in + i * dilation_t; + auto h = h_in + j * dilation_h; + auto w = w_in + k * dilation_w; + *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height && + w < width) + ? data_vol + [i * dilation_t * height * width + j * dilation_h * width + + k * dilation_w] + : static_cast(0); + data_col += depth_col * height_col * width_col; + } + } + } + } +} + +template +void vol2col( + cudaStream_t stream, + const T* data_vol, + const int channels, + const int depth, + const int height, + const int width, + const int depth_col, + const int height_col, + const int width_col, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + T* data_col) { + // We are going to launch channels * depth_col * height_col * width_col + // kernels, each kernel responsible for copying a single-channel grid. + // We cast an operand to int64 so that the product will not overflow + const auto num_kernels = static_cast(channels) * depth_col * height_col * width_col; + // Launch + vol2col_kernel<<>>( + num_kernels, + data_vol, + depth, + height, + width, + ksize_t, + ksize_h, + ksize_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + depth_col, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__global__ void vol2im_kernel( + const int64_t n, + const T* data_col, + const unsigned depth, + const unsigned height, + const unsigned width, + const unsigned channels, + const unsigned kernel_t, + const unsigned kernel_h, + const unsigned kernel_w, + const unsigned pad_t, + const unsigned pad_h, + const unsigned pad_w, + const unsigned stride_t, + const unsigned stride_h, + const unsigned stride_w, + const unsigned dilation_t, + const unsigned dilation_h, + const unsigned dilation_w, + const unsigned depth_col, + const unsigned height_col, + const unsigned width_col, + T* data_vol) { + CUDA_KERNEL_LOOP(index, n) { + accT val = static_cast(0); + const auto w_im = index % width + pad_w; + const auto h_im = (index / width) % height + pad_h; + const auto t_im = (index / width / height) % depth + pad_t; + const auto c_im = index / (width * height * depth); + auto kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + auto kernel_extent_h = (kernel_h - 1) * dilation_h + 1; + auto kernel_extent_t = (kernel_t - 1) * dilation_t + 1; + // compute the start and end of the output + const auto w_col_start = + (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; + const auto w_col_end = std::min(w_im / stride_w + 1, width_col); + const auto h_col_start = + (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; + const auto h_col_end = std::min(h_im / stride_h + 1, height_col); + const auto t_col_start = + (t_im < kernel_extent_t) ? 
0 : (t_im - kernel_extent_t) / stride_t + 1; + const auto t_col_end = std::min(t_im / stride_t + 1, depth_col); + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) { + for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) { + uint64_t t_k = (t_im - t_col * stride_t); + uint64_t h_k = (h_im - h_col * stride_h); + uint64_t w_k = (w_im - w_col * stride_w); + if (t_k % dilation_t == 0 && h_k % dilation_h == 0 && + w_k % dilation_w == 0) { + t_k /= dilation_t; + h_k /= dilation_h; + w_k /= dilation_w; + const int64_t idx_k = + ((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k; + const int64_t data_col_index = + ((idx_k * depth_col + t_col) * + height_col + h_col) * + width_col + w_col; + val += data_col[data_col_index]; + } + } + } + } + data_vol[index] = static_cast(val); + } +} + +template +void col2vol( + cudaStream_t stream, + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t output_depth, + const int64_t output_height, + const int64_t output_width, + const int64_t patch_t, + const int64_t patch_h, + const int64_t patch_w, + const int64_t pad_t, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_t, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_t, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_vol) { + const auto num_kernels = channels * depth * height * width; + + auto check_fits_in_unsigned = + [](int64_t val, const char * name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK(val >= 0 && val <= umax, + name, " must fit in a 32-bit unsigned value"); + }; + check_fits_in_unsigned(num_kernels, "input size"); + check_fits_in_unsigned( + channels * patch_t * patch_h * patch_w, "channels x kernel size"); + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. 
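+ // (In other words, the kernel is a gather: each thread owns one input voxel
+ // and sums every column entry that maps back onto it, so no two threads write
+ // the same output element and no atomics are needed.)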
+ vol2im_kernel + <<>>( + num_kernels, + data_col, + depth, + height, + width, + channels, + patch_t, + patch_h, + patch_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + output_depth, + output_height, + output_width, + data_vol); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/group_norm.h b/phivenv/Lib/site-packages/torch/include/ATen/native/group_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..6a4e079db4e24a3980e8fc9cf84454c5073ad437 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/group_norm.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using forward_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + int64_t /* N */, + int64_t /* C */, + int64_t /* HxW */, + int64_t /* group */, + double /* eps */, + Tensor& /* Y */, + Tensor& /* mean */, + Tensor& /* rstd */); + +using backward_fn = void (*)( + const Tensor& /* dY */, + const Tensor& /* X */, + const Tensor& /* mean */, + const Tensor& /* rstd */, + const Tensor& /* gamma */, + int64_t /* N */, + int64_t /* C */, + int64_t /* HxW */, + int64_t /* group */, + Tensor& /* dX */, + Tensor& /* dgamma */, + Tensor& /* dbeta */); + +DECLARE_DISPATCH(forward_fn, GroupNormKernel) +DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel) + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_collection.h b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_collection.h new file mode 100644 index 0000000000000000000000000000000000000000..5b653d519ccb1167b9d5d9e39a44f319682f886d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_collection.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +namespace at::native { +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v4(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_16x16x1_16x16x1_1x16x1x16_4_Intrawave_v4(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_4x64x1_4x64x1_1x16x1x16_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_4x64x1_4x64x1_1x16x1x16_4_Intrawave_v5(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_224x256x64_16x16_7x8_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x224x64_16x16_8x7_8x32x1_8x32x1_1x32x1x8_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v5(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_32x16x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); 
+void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x16x64_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x16x64_16x16_4x1_16x16x1_16x8x1_1x32x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x16x64_16x16_4x1_32x8x1_32x4x1_1x32x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_128x16x64_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_64x16x64_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_32x16x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x64x64_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x128x64_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_16x256x64_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); + +}; // namespace at::native \ No newline at end of file diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h new file mode 100644 index 0000000000000000000000000000000000000000..2dc07c984297f5fa926786f17cfcf91064e9bbaa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h @@ -0,0 +1,164 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// Define commonly used types. 
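A reading aid rather than part of the original header: the bare names m, n, k, a, b, c, lda, ldb, ldc and num_batches used inside bgemm_kernel_impl below are not globals; they are the parameters introduced by the CUDABLAS_BGEMM_ARGTYPES(at::BFloat16) macro declared in ATen's CUDABlas header. The comment block below is an approximate paraphrase of that parameter list, not the authoritative expansion.

// Approximate shape of CUDABLAS_BGEMM_ARGTYPES(Dtype); exact names and order
// are defined in ATen/cuda/CUDABlas.h:
//   char transa, char transb,
//   int64_t m, int64_t n, int64_t k,
//   opmath-typed alpha,
//   const Dtype* a, int64_t lda, int64_t stridea,
//   const Dtype* b, int64_t ldb, int64_t strideb,
//   opmath-typed beta,
//   Dtype* c, int64_t ldc, int64_t stridec,
//   int64_t num_batches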
+template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using AccDataType = F32; +using DsDataType = ck::Tuple<>; +using CDataType = BF16; +using CShuffleDataType = BF16; +using DsLayout = ck::Tuple<>; +using CLayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template < + typename A_DATA_TYPE, + typename B_DATA_TYPE, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int AK1, + int BK1, + int WAVE_TILE_M, + int WAVE_TILE_N, + int WAVE_MAP_M, + int WAVE_MAP_N, + typename ABLOCK_TRANSFER, + int ABLOCK_TRANSFER_SSPV, + int ABLOCK_TRANSFER_DSPV_K1, + typename BBLOCK_TRANSFER, + int BBLOCK_TRANSFER_SSPV, + int BBLOCK_TRANSFER_SSPV_K1, + int CSHUFFLE_MXDL_PWPS, + int CSHUFFLE_NXDL_PWPS, + typename CSHUFFLEBLOCK_TRANSFER, + typename CDESHUFFLEBLOCK_TRANSFER, + ck::BlockGemmPipelineScheduler LOOP_SCHED, + ck::BlockGemmPipelineVersion PIPELINE_VERSION, + ck::tensor_operation::device::GemmSpecialization GEMM_SPEC = + ck::tensor_operation::device::GemmSpecialization::MNPadding, + bool TRANSA = false, + bool TRANSB = false> +void bgemm_kernel_impl(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< + ALayout, // ALayout + BLayout, // BLayout + DsLayout, // DsLayout + CLayout, // CLayout + ADataType, // ADataType + BDataType, // BDataType + DsDataType, // DsDataType + CDataType, // CDataType + AccDataType, // AccDataType + CShuffleDataType, // CshuffleType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CDEElementOp, // CElementwiseOperation + GEMM_SPEC, // GEMMSpecialization + BLOCK_SIZE, // BlockSize + MBLOCK, // MPerBlock + NBLOCK, // NPerBlock + KBLOCK, // KPerBlock + AK1, // AK1 + BK1, // BK1 + WAVE_TILE_M, // MPerXDL + WAVE_TILE_N, // NPerXDL + WAVE_MAP_M, // MXdlPerWave + WAVE_MAP_N, // NXdlPerWave + ABLOCK_TRANSFER, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + ABLOCK_TRANSFER_SSPV, // ABlockTransferSrcScalarPerVector + ABLOCK_TRANSFER_DSPV_K1, // ABlockTransferDstScalarPerVector_AK1 + 0, // ABlockLdsExtraM + BBLOCK_TRANSFER, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + BBLOCK_TRANSFER_SSPV, // BBlockTransferSrcScalarPerVector + BBLOCK_TRANSFER_SSPV_K1, // BBlockTransferDstScalarPerVector_BK1 + 0, // BBlockLdsAddExtraN + CSHUFFLE_MXDL_PWPS, // CShuffleMXdlPerWavePerShuffle + CSHUFFLE_NXDL_PWPS, // CShuffleNXdlPerWavePerShuffle + CSHUFFLEBLOCK_TRANSFER, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + CDESHUFFLEBLOCK_TRANSFER, // CDEShuffleBlockTransferScalarPerVectors + LOOP_SCHED, // BlockGemmPipelineScheduler + PIPELINE_VERSION // BlockGemmPipelineVersion + >{}; + auto invoker = 
gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( + b, // A and B are swapped for CK + a, + {}, + c, + n, + m, + k, + num_batches, + ldb, + lda, + {}, + ldc, + n * k, // batch_stride_a + m * k, // batch_stride_b + {}, + m * n, // batch_stride_c + a_element_op, + b_element_op, + cde_element_op + ); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); +} + +}; // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_bgemm.h b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_bgemm.h new file mode 100644 index 0000000000000000000000000000000000000000..c9d09accbe263fc8a93a60a57c70b630bc34d453 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_bgemm.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace at::native { + +template +inline void bgemm_internal_ck(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas_bgemm_internal_ck: not implemented"); +} + +template <> +void bgemm_internal_ck(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_gemm.h b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..d45ac13d738275d6b6232e2e019264127abcfcfd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_gemm.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +namespace at::native { + + +template +inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); +} + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); + + + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_gemm_template.h b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_gemm_template.h new file mode 100644 index 0000000000000000000000000000000000000000..31b5b094b55f57838d117a317107156439881747 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/hip/ck_gemm_template.h @@ -0,0 +1,416 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// Define commonly used types. 
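Editor's sketch, not part of the vendored file: both this template and the batched one above instantiate the CK device GEMM with a row-major C layout, while their callers expect a cuBLAS-style column-major m-by-n output; that is why the arguments are passed as (B, A) with m and n exchanged. Since C^T = B^T * A^T, the row-major n-by-m product CK writes lands in exactly the memory layout the column-major caller handed in. The two helpers below are hypothetical and only illustrate the address identity being relied on.

#include <cstdint>

// Column-major element (i, j) of an m-by-n matrix with leading dimension ld
// lives at the same offset as row-major element (j, i) of the n-by-m transpose.
inline int64_t col_major_offset(int64_t i, int64_t j, int64_t ld) {
  return i + j * ld;
}
inline int64_t row_major_offset(int64_t row, int64_t col, int64_t ld) {
  return row * ld + col;
}
// col_major_offset(i, j, ld) == row_major_offset(j, i, ld) for all i, j.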
+template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +namespace at::native { + +// Elementwise Operators +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(C& c, const AB& ab) const; + + template<> + __host__ __device__ constexpr void operator() + (float& c, const float& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::bhalf_t& c, const ck::bhalf_t& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::half_t& c, const ck::half_t& ab) const + { + c = alpha_ * ab; + }; + + float alpha_; + // TODO: Leaving for now, will use later + float beta_; +}; + +template < + typename Dtype, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int AK1, + int BK1, + int MPER_XDL, + int NPER_XDL, + int MPER_WAVE, + int NPER_WAVE, + typename ABLOCK_CLUSTER_LENS, + typename ABLOCK_CLUSTER_ORDER, + typename ABLOCK_SRC_ORDER, + int ABLOCK_VECTOR_DIM, + int ABLOCK_SCALAR_VEC, + int ABLOCK_SCALAR_VEC_AK1, + bool ABLOCK_LDS_EXTRAM, + typename BBLOCK_CLUSTER_LENS, + typename BBLOCK_CLUSTER_ORDER, + typename BBLOCK_SRC_ORDER, + int BBLOCK_VECTOR_DIM, + int BBLOCK_SCALAR_VEC, + int BBLOCK_SCALAR_VEC_AK1, + bool BBLOCK_LDS_EXTRAN, + int CMPER_WAVE, + int CNPER_WAVE, + typename BLOCK_CLUSTER_LENS, + typename CDE_SCALAR_VEC, + bool PADDING = false, + bool TRANSA = false, + bool TRANSB = false> +void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // Get input information. + int M = m; + int N = n; + int K = k; + + int StrideA = lda; + int StrideB = ldb; + int StrideC = ldc; + + int KBatch = 1; + + float falpha = alpha; + float fbeta = beta; + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + using CDataType = typename CkMathType::dtype; + using DDataType = typename CkMathType::dtype; + + using AccDataType = float; + using CShuffleDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + using DLayout = Row; + using CLayout = Row; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = AlphaBetaAdd; + + + static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr auto GemmSpec = PADDING ? 
GemmMNKPadding : GemmDefault; + + + using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmSpec, + BLOCK_SIZE, + MBLOCK, + NBLOCK, + KBLOCK, + AK1, + BK1, + MPER_XDL, + NPER_XDL, + MPER_WAVE, + NPER_WAVE, + ABLOCK_CLUSTER_LENS, + ABLOCK_CLUSTER_ORDER, + ABLOCK_SRC_ORDER, + ABLOCK_VECTOR_DIM, + ABLOCK_SCALAR_VEC, + ABLOCK_SCALAR_VEC_AK1, + ABLOCK_LDS_EXTRAM, + BBLOCK_CLUSTER_LENS, + BBLOCK_CLUSTER_ORDER, + BBLOCK_SRC_ORDER, + BBLOCK_VECTOR_DIM, + BBLOCK_SCALAR_VEC, + BBLOCK_SCALAR_VEC_AK1, + BBLOCK_LDS_EXTRAN, + CMPER_WAVE, + CNPER_WAVE, + BLOCK_CLUSTER_LENS, + CDE_SCALAR_VEC>; + + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{alpha, beta}; + + + using DDataArrayType = std::array; + DDataArrayType DDataArray; + + // We swap A and B inputs here as a temporary workaround + auto argument = gemm.MakeArgument( + reinterpret_cast(b), + reinterpret_cast(a), + DDataArray, + reinterpret_cast(c), + N, + M, + K, + StrideB, + StrideA, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); +} + + +template < + typename Dtype, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int K1, + int MPER_WMMA, + int NPER_WMMA, + int MPER_WAVE, + int NPER_WAVE, + typename ABLOCK_CLUSTER_LENS, + typename ABLOCK_CLUSTER_ORDER, + typename ABLOCK_SRC_ORDER, + int ABLOCK_VECTOR_DIM, + int ABLOCK_SCALAR_VEC, + int ABLOCK_SCALAR_VEC_K1, + bool ABLOCK_LDS_EXTRAM, + typename BBLOCK_CLUSTER_LENS, + typename BBLOCK_CLUSTER_ORDER, + typename BBLOCK_SRC_ORDER, + int BBLOCK_VECTOR_DIM, + int BBLOCK_SCALAR_VEC, + int BBLOCK_SCALAR_VEC_AK1, + bool BBLOCK_LDS_EXTRAN, + int CMPER_WAVE, + int CNPER_WAVE, + typename CBLOCK_CLUSTER_LENS, + int CNPER_BLOCK, + bool PADDING = false, + bool TRANSA = false, + bool TRANSB = false> +void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // Get input information. + int M = m; + int N = n; + int K = k; + + int StrideA = lda; + int StrideB = ldb; + int StrideC = ldc; + + int KBatch = 1; + + float falpha = alpha; + float fbeta = beta; + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + using CDataType = typename CkMathType::dtype; + using DDataType = typename CkMathType::dtype; + + using AccDataType = float; + using CShuffleDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + using DLayout = Row; + using CLayout = Row; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = PassThrough; + + + static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr auto GemmSpec = PADDING ? 
GemmMNKPadding : GemmDefault; + + + using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemmWmma_CShuffle; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + + using DDataArrayType = std::array; + DDataArrayType DDataArray; + + // We swap A and B inputs here as a temporary workaround + auto argument = gemm.MakeArgument( + reinterpret_cast(b), + reinterpret_cast(a), + reinterpret_cast(c), + N, + M, + K, + StrideB, + StrideA, + StrideC, + b_element_op, + a_element_op, + c_element_op); + + + if(!gemm.IsSupportedArgument(argument)) + { + printf("error shape = %ld %ld %ld TRANSA=%d TRANSB=%d \n", + n, m, k,TRANSA, TRANSB); + + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + + auto stream = at::cuda::getCurrentHIPStream().stream(); +#if 1 + invoker.Run(argument, StreamConfig{stream, false}); +#else + float ave_time = invoker.Run(argument, StreamConfig{stream, true}); + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << N <<" " < +#include +#include +#include + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace at::native { + +template +struct CkMathType { + using dtype = T; +}; + +template <> +struct CkMathType { + using dtype = ck::bhalf_t; +}; + +template <> +struct CkMathType { + using dtype = ck::half_t; +}; + +template +struct CkTensorLayout { + // default goes to row-wise for now + using a_layout = Row; + using b_layout = Row; +}; + +// True denotes transpose is necessary. 
Default is Col, so return Row +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Col; +}; + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Col; +}; + +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Row; +}; + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Row; +}; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/im2col.h b/phivenv/Lib/site-packages/torch/include/ATen/native/im2col.h new file mode 100644 index 0000000000000000000000000000000000000000..d9ac61a078e53ba891a89fd6f335441a0486ada1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/im2col.h @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +template +static void im2col( + const T* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_col, + bool is_channels_last = false) { + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + if (is_channels_last) { + at::parallel_for(0, height_col * width_col, 0, [&](int64_t begin, int64_t end) { + int64_t h_col{0}, w_col{0}; + data_index_init(begin, h_col, height_col, w_col, width_col); + + for (const auto i_col : c10::irange(begin, end)) { + for (const auto h_offset : c10::irange(kernel_h)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_offset : c10::irange(kernel_w)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + const T* slice_im = data_im + (h_im * width + w_im) * channels; + T* slice_col = data_col + (i_col * kernel_h * kernel_w + h_offset * kernel_w + w_offset) * channels; + + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + std::copy_n(slice_im, channels, slice_col); + } else { + std::fill_n(slice_col, channels, T(0)); + } + } + } + + // move the next index + data_index_step(h_col, height_col, w_col, width_col); + } + }); + } else { + at::parallel_for(0, channels_col, 0, [&](int64_t begin, int64_t end) { + int64_t c_im{0}, h_offset{0}, w_offset{0}; + data_index_init(begin, c_im, channels, h_offset, kernel_h, w_offset, kernel_w); + + for (const auto c_col : c10::irange(begin, end)) { + for (const auto h_col : c10::irange(height_col)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_col : c10::irange(width_col)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + data_col[(c_col * height_col + h_col) * width_col + w_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? 
c10::load(&(data_im[(c_im * height + h_im) * width + w_im])) + : static_cast(0); + } + } + + // move to the next index + data_index_step(c_im, channels, h_offset, kernel_h, w_offset, kernel_w); + } + }); + } +} + +template +static void col2im( + const T* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_im, + bool is_channels_last = false) { + std::fill_n(data_im, height * width * channels, T(0)); + + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + if (is_channels_last) { + for (const auto h_col : c10::irange(height_col)) { + for (const auto w_col : c10::irange(width_col)) { + for (const auto h_offset : c10::irange(kernel_h)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_offset : c10::irange(kernel_w)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + T* slice_im = data_im + (h_im * width + w_im) * channels; + const T* slice_col = data_col + ((h_col * width_col + w_col) * kernel_h * kernel_w + + h_offset * kernel_w + w_offset) * channels; + + if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) { + std::transform(slice_col, slice_col + channels, slice_im, slice_im, std::plus()); + } + } + } + } + } + } else { + for (const auto c_col : c10::irange(channels_col)) { + int64_t w_offset = c_col % kernel_w; + int64_t h_offset = (c_col / kernel_w) % kernel_h; + int64_t c_im = c_col / kernel_h / kernel_w; + + for (const auto h_col : c10::irange(height_col)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_col : c10::irange(width_col)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) + data_im[(c_im * height + h_im) * width + w_im] += + data_col[(c_col * height_col + h_col) * width_col + w_col]; + } + } + } + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/im2col_shape_check.h b/phivenv/Lib/site-packages/torch/include/ATen/native/im2col_shape_check.h new file mode 100644 index 0000000000000000000000000000000000000000..fb3b82d60697d9ea5eb596097c45060ec974d72a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/im2col_shape_check.h @@ -0,0 +1,232 @@ +#pragma once +#include +#include +#include + +namespace at::native { + +inline void col2im_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t output_height, + int64_t output_width, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + 
dilation_height, + " dilation_width: ", + dilation_width); + TORCH_CHECK( + pad_width >= 0 && pad_height >= 0, + "padding should be non-negative, but got pad_height: ", + pad_height, + " pad_width: ", + pad_width); + + + int64_t ndim = input.ndimension(); + // allow dim=0 only the batch dimension. + TORCH_CHECK( + (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) || + (ndim == 3 && input.size(1) != 0 && input.size(2) != 0), + "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ", + input.sizes()); + + int64_t batch_dim = (ndim == 3) ? 0 : -1; + int64_t n_input_plane = input.size(batch_dim + 1); + + if (n_input_plane % (kernel_width * kernel_height) != 0) { + TORCH_CHECK(false, + "Expected size of input's dimension 1 to be divisible by the " + "product of kernel_size, but got input.size(1)=", + n_input_plane, + " and kernel_size=(", + kernel_height, + ", ", + kernel_width, + ")."); + } + + int64_t input_length = input.size(batch_dim + 2); + int64_t n_blocks_height = + div_rtn( + output_height + 2 * pad_height - + dilation_height * (kernel_height - 1) - 1, + stride_height) + + 1; + int64_t n_blocks_width = div_rtn( + output_width + 2 * pad_width - + dilation_width * (kernel_width - 1) - 1, + stride_width) + + 1; + + if (input_length != (n_blocks_height * n_blocks_width)) { + TORCH_CHECK(false, + "Given output_size=(", + output_height, + ", ", + output_width, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), stride=(", + stride_height, + ", ", + stride_width, + "), expected size of input's dimension 2 to match the calculated number of ", + "sliding blocks ", + n_blocks_height, + " * ", + n_blocks_width, + " = ", + (n_blocks_height * n_blocks_width), + ", but got input.size(2)=", + input_length, + "."); + } + + TORCH_CHECK( + n_blocks_height >= 1 && n_blocks_width >= 1, + "Given output_size=(", output_height, ", ", output_width, "), ", + "kernel_size=(", kernel_height, ", ", kernel_width, "), ", + "dilation=(", dilation_height, ", ", dilation_width, "), ", + "padding=(", pad_height, ", ", pad_width, "), ", + "stride=(", stride_height, ", ", stride_width, "), ", + "calculated shape of the array of sliding blocks as ", + "(", n_blocks_height, ", ", n_blocks_width, "), ", + "which is too small (non-positive)"); + + if (output_width < 1 || output_height < 1) { + TORCH_CHECK(false, + "Expected output spatial size to be positive, but got: output_size=(", + output_height, + ", ", + output_width, + ")."); + } +} + +inline void im2col_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + dilation_height, + " dilation_width: ", + dilation_width); + + TORCH_CHECK( + pad_width >= 0 && pad_height >= 0, + "padding should be non-negative, but got pad_height: ", + pad_height, + " pad_width: ", + pad_width); + + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater 
than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + + int64_t ndim = input.ndimension(); + + // allow dim=0 only the batch dimension. + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + TORCH_CHECK( + (ndim == 3 && input.size(0) && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); + + int64_t dim_batch = 0; + + if (ndim == 3) { + dim_batch = -1; + } + + int64_t input_height = input.size(dim_batch + 2); + int64_t input_width = input.size(dim_batch + 3); + int64_t output_height = div_rtn( + input_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1), + stride_height) + + 1; + int64_t output_width = div_rtn( + input_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1), + stride_width) + + 1; + + if (output_height < 1 || output_width < 1) { + TORCH_CHECK(false, + "Given input with spatial size (", + input_height, + ", ", + input_height, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), calculated shape of the array of sliding blocks as (", + output_height, + ", ", + output_width, + "), but its components must be at least one."); + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_kernels.h b/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..aae1495c0c92e55aae6a93970ba15043bd7b9a63 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_kernels.h @@ -0,0 +1,42 @@ +#pragma once +#include +#include +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +/** + * @brief Rearranges the quantized weight to support kleidiai inference + * @param bl Groupsize for quantization should be multiple of 32 + */ +void kai_pack_int4_rhs( + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k, + const int64_t bl); + +/** + * @brief Outputs the buffer size for the packed weights + * @param bl Groupsize for quantization. 
32 for groupwise , 0 for channelwise + */ +size_t kai_pack_rhs_int4_size( + const int64_t n, + const int64_t k, + const int64_t bl); + +/** + * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul ) + */ +void kai_quant_pack_lhs_int4_mm( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t bl); +} // namespace at::native::kleidiai +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_pack.h b/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_pack.h new file mode 100644 index 0000000000000000000000000000000000000000..9fafb06b50e7d88624074da17b5784c7b161361e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_pack.h @@ -0,0 +1,106 @@ +#pragma once +#include +#include +#include +#include +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +template +void kai_pack_rhs_groupwise_int4( + T& kernel, + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k, + const int64_t bl, + const int64_t rhs_stride, + const int64_t scale_stride) { + const auto& ukernel = kernel.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + auto weight_packed_data = + reinterpret_cast(weight_packed.data_ptr()); + const auto weight_data = weight.data_ptr(); + auto scales_data = scales.const_data_ptr(); + + if (weight_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); + } + + if (scales_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); + } + + float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; + auto& params = kernel.rhs_pack_params; + + kernel.kai_run_rhs_pack( + /*num_groups=*/1, + n, + k, + nr, + kr, + sr, + bl, + (const uint8_t*)(weight_data), + rhs_stride, + bias_ptr, + scales_data, + scale_stride, + weight_packed_data, + 0, + ¶ms); +} + +template +void kai_pack_rhs_channelwise_int4( + T& kernel, + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k) { + const auto& ukernel = kernel.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + auto weight_packed_data = + reinterpret_cast(weight_packed.data_ptr()); + const auto weight_data = weight.data_ptr(); + const auto scales_data = scales.data_ptr(); + + if (weight_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); + } + + if (scales_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); + } + + float* bias_ptr = bias.has_value() ? 
bias.value().data_ptr() : NULL; + auto& params = kernel.rhs_pack_params; + + kernel.kai_run_rhs_pack( + /*num_groups=*/1, + n, + k, + nr, + kr, + sr, + (const uint8_t*)(weight_data), + (const float*)(bias_ptr), + (const float*)(scales_data), + weight_packed_data, + 0, + ¶ms); +} + +} // namespace at::native::kleidiai + +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_ukernel_interface.h b/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_ukernel_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..f8a8bd421c0c447ddf5f182ca7f0763f1cb5220b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/kleidiai/kai_ukernel_interface.h @@ -0,0 +1,144 @@ +#pragma once +#include +#if AT_KLEIDIAI_ENABLED() +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native::kleidiai { + +enum class kai_kernel_id { + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod = + 0, // Groupwise 4 bit GEMV + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm = + 1, // Groupwise 4 bit GEMM + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod = + 2, // Channelwise 4 bit GEMV + matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm = + 3 // Channelwise 4 bit GEMM +}; + +// Channelwise Kernel mapping +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const float* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* rhs, + const float* bias, + const float* scale, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); + + kai_matmul_ukernel_f32_qa8dxp_qs4cxp( + const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), + kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {} +}; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp +kai_select_channelwise_matmul_ukernel(const kai_kernel_id id); + +// Groupwise Kernel mapping +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { + struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + enum kai_datatype scale_dt); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const float* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + const uint8_t* rhs, + 
size_t rhs_stride, + const float* bias, + const void* scale, + size_t scale_stride, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); + + kai_matmul_ukernel_f32_qa8dxp_qs4c32p( + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), + kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {} +}; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( + const kai_kernel_id id); + +} // namespace at::native::kleidiai +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/layer_norm.h b/phivenv/Lib/site-packages/torch/include/ATen/native/layer_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..abd2cc0e06d2bda75cab626e0cce81f628c005e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/layer_norm.h @@ -0,0 +1,141 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +namespace { + +C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const Tensor& weight /* optional */) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sym_sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sym_sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_ndim = input.dim(); + const auto input_shape = input.sym_sizes(); + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + TORCH_CHECK(false, ss.str()); + } +} + +C10_ALWAYS_INLINE std::pair _check_layer_norm_inputs( + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << 
normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + TORCH_CHECK(false, ss.str()); + } + + const int axis = input_ndim - normalized_ndim; + const int64_t M = + c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis); + const int64_t N = + c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend()); + + return std::make_pair(M, N); +} + +} // namespace + +void layer_norm_cpu_out( + at::Tensor& out, + const at::Tensor& input, + const Tensor& gamma, + const Tensor& beta, + double eps, + int64_t M, + int64_t N); + +Tensor rms_norm_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps); + +using forward_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */, + Tensor* /* mean */, + Tensor* /* rstd */); + +using backward_fn = void (*)( + const Tensor& /* dY */, + const Tensor& /* X */, + const Tensor& /* mean */, + const Tensor& /* rstd */, + const Tensor& /* gamma */, + int64_t /* M */, + int64_t /* N */, + Tensor* /* dX */, + Tensor* /* dgamma */, + Tensor* /* dbeta */); + +DECLARE_DISPATCH(forward_fn, LayerNormKernel) +DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/Conv.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/Conv.h new file mode 100644 index 0000000000000000000000000000000000000000..81322ab2cf26b11faf149756aa03c502f9a66407 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/Conv.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include + +#if AT_MKLDNN_ENABLED() + +namespace at::native::xpu { +C10_API Tensor convolution_pointwise( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view attr, + torch::List> scalars, + std::optional algorithm); + +C10_API Tensor convolution_pointwise_binary( + const Tensor& input_t, + const Tensor& other_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +C10_API Tensor& convolution_pointwise_binary_( + Tensor& other_t, + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +} // namespace at::native::xpu + +#endif // AT_MKLDNN_ENABLED() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/FusionUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/FusionUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..5b92cb4b6b69e0452392eb6a768c520510820007 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/FusionUtils.h @@ -0,0 +1,53 @@ +#pragma once +#include + +// +// This header file provides utility functions for 
constructing and managing +// oneDNN attributes used in fusion operations on XPU devices. These utilities +// include functions for creating unary and binary post-operations attributes, +// as well as mapping string representations of operations to oneDNN attributes. +// + +namespace at::native::xpu { +at::native::onednn::Attr& unary_attr_with_arg( + onednn::Attr& attr, + std::string_view unary, + torch::List> scalars, + std::optional algorithm); + +at::native::onednn::Attr& string_to_unary_attr( + onednn::Attr& attr, + std::string_view unary); + +at::native::onednn::Attr& construct_unary_attr( + onednn::Attr& attr, + std::string_view unary, + torch::List> scalars, + std::optional algorithm); + +template +onednn::Attr& construct_binary_attr( + onednn::Attr& attr, + std::string_view binary, + const Tensor& other) { + if (binary == "mul") { + attr.append_post_binary(attr.kind_with_binary_mul, other); + } else if (binary == "sub") { + attr.append_post_binary(attr.kind_with_binary_sub, other); + } else if (binary == "div") { + attr.append_post_binary(attr.kind_with_binary_div, other); + } else if (binary == "add") { + attr.append_post_binary(attr.kind_with_binary_add, other); + } else if (binary == "sum") { + attr.append_post_sum(1.f, 1.f, 0); + } else { + TORCH_CHECK( + binary == "none", + "Binary attr ", + binary, + "is not supported for conv/linear post binary fusion"); + } + return attr; +} + +} // namespace at::native::xpu diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Attr.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Attr.h new file mode 100644 index 0000000000000000000000000000000000000000..692fb9555c887bb4885fe7f4830cd055f66e034a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Attr.h @@ -0,0 +1,463 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::native::onednn { +/* oneDNN quantization usage: + https://oneapi-src.github.io/oneDNN/dev_guide_attributes_quantization.html# + + src_fp32 = scale_src * (src_int8 - zero_point) + wei_fp32 = scale_wei * (wei_int8 - zero_point) + dst_fp32 = scale_dst * (dst_int8 - zero_point) + fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32; + + Considering zero-point (asymmetric): + dst_fp32 = (src_int8 - src_zp) * src_sc * wei_int8 * wei_sc + dst_sc * (dst_int8 - dst_zp) = (src_int8 - src_zp) * wei_int8 * src_sc * + wei_sc + dst_int8 = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc / dst_sc + + dst_zp + + considering bias: + fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + bias + Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + + bias Int8 Convolution: dst_fp32 = (src_int8 * wei_int8 + bias/(scale_src * + scale_wei)) * (scale_src * scale_wei) Int8 Convolution: dst_int8 = 1 / + scale_dst * dst_fp32; +*/ + +/* + oneDNN postops usage: + Currently, oneDNN supports 5 kinds of post ops. More details can be refered +to oneDNN doc. + https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise + +0. without post ops + dst = Conv(src, wei) + bias; + dst_int8 = 1/q_scale * dst; q_scale is the op output quantization scale + fp32 API: Attr attr; + int8 API: Attr attr(q_scale); + +1. 
append eltwise post op + dst = elt_scale * Eltwise{conv_scale * [Conv(src, wei) + bias], alpha, beta} + dst_int8 = 1/q_scale * dst; + fp32 API: + Attr attr; + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) + int8 API: + Attr attr(q_scale); + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) + +2. append sum post op + dst = conv_scale * Conv(src, wei) + sum_scale * (dst - zp) + dst_int8 = 1/q_scale * dst; + fp32 API: + Attr attr; + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_sum(sum_scale) + int8 API: + Attr attr(q_scale); + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_sum(sum_scale) + +3. append binary post op + dst = Binary[Conv(src, wei)] + +*/ +using kind_t = dnnl::primitive::kind; +struct PostOpParam { + // eltwise post op constructor + PostOpParam( + float scale, + float alpha, + float beta, + dnnl::algorithm algo, + kind_t kind) + : scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {} + // sum post op constructor + PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {} + // sum post op with zp + PostOpParam(float scale, int64_t zero_point, kind_t kind) + : scale_(scale), zero_point_(zero_point), kind_(kind) {} + // binary post op constructor + PostOpParam( + at::Tensor& binary, + dnnl::memory::desc& binary_md, + dnnl::memory::desc& expected_md, + dnnl::algorithm algo, + kind_t kind) + : binary_(binary), + meta_(binary_md), + expected_meta_(expected_md), + algo_(algo), + kind_(kind) {} + // prelu post op constructor + PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {} + + // post sum or binary with scale post op constructor + PostOpParam( + at::Tensor& binary, + float scale, + dnnl::algorithm algo, + kind_t kind) + : scale_(scale), binary_(binary), algo_(algo), kind_(kind) {} + + // for int8 sum/eltwise + float scale_ = 1.0; + int64_t zero_point_ = 0; + // for eltwise + float alpha_ = 0.0; + float beta_ = 0.0; + // for binary + at::Tensor binary_ = at::Tensor(); + at::Tensor expected_binary_ = at::Tensor(); + void* binary_ptr_ = nullptr; + dnnl::memory::desc meta_ = dnnl::memory::desc(); + dnnl::memory::desc expected_meta_ = dnnl::memory::desc(); + // for prelu + int mask_ = 0; + // common + dnnl::algorithm algo_ = dnnl::algorithm::eltwise_relu; + kind_t kind_ = kind_t::eltwise; +}; + +class Attr { + public: + Attr() : q_scale_(1.f) {} + Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {} + + /***** eltwise *****/ + dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu; + dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic; + dnnl::algorithm kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh; + dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf; + dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish; + dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear; + dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish; + dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt; + dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh; + dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square; + dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs; + dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp; + dnnl::algorithm kind_with_log = 
dnnl::algorithm::eltwise_log; + dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round; + dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish; + dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu; + dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu; + dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow; + dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip; + // note: hardsigmoid seems oneDNN still not support + dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid; + + /***** binary *****/ + dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul; + dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add; + dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub; + dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div; + dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq; + dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne; + dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge; + dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt; + dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le; + dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt; + dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max; + dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min; + + // append sum post op + Attr& append_post_sum( + float sum_scale, + float sum_q_scale = 1.f, + int64_t zp = 0) { + ops_params_.push_back( + PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, zp, kind_t::sum)); + return *this; + } + + // append eltwise post op + Attr& append_post_eltwise( + float scale, + float alpha, + float beta, + dnnl::algorithm algo) { + ops_params_.push_back( + PostOpParam(scale, alpha, beta, algo, kind_t::eltwise)); + return *this; + } + + // append binary post op + template + Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) { + auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary; + bool binary_is_channels_last = + (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast || + binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d); + + if constexpr (!is_matmul) { + binary_ = binary_is_channels_last ? 
binary_ : binary_.contiguous(); + } + dnnl::memory::desc md = get_onednn_md(binary_); + auto expected_md = dnnl::memory::desc( + md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any); + if constexpr (is_matmul) { + ops_params_.push_back(PostOpParam(binary_, md, md, algo, kind_t::binary)); + } else { + ops_params_.push_back( + PostOpParam(binary_, md, expected_md, algo, kind_t::binary)); + } + + return *this; + } + + Attr& append_scale_binary( + dnnl::algorithm algo, + at::Tensor binary, + float scale, + float sum_q_scale = 1.f, + int64_t zp = 0) { + ops_params_.push_back(PostOpParam( + binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary)); + return *this; + } + + // append bias with binary_add method (only used for QConv now) + Attr& append_bias(const at::Tensor& binary, const int ndimension) { + // In PyTorch, bias are in shape of [OC], + // we expand its shape according to Conv dimension + // Conv1d [OC, 1, 1], Conv2d [1, OC, 1, ,1], Conv3d [1, OC, 1, 1, 1] + at::Tensor binary_ = binary.contiguous(); + dnnl::memory::desc binary_md; + switch (ndimension) { + case 1: + binary_md = dnnl::memory::desc( + {binary.size(0), 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abc); + break; + case 2: + binary_md = dnnl::memory::desc( + {1, binary.size(0), 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abcd); + break; + case 3: + binary_md = dnnl::memory::desc( + {1, binary.size(0), 1, 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abcde); + break; + default: + TORCH_INTERNAL_ASSERT( + 0, "XPU only supports append_bias for Conv1d, Conv2d and Conv3d."); + } + // In this case, expected_md = binary_md + ops_params_.push_back(PostOpParam( + binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary)); + return *this; + } + + // append prelu post op + Attr& append_post_prelu(int mask) { + ops_params_.push_back(PostOpParam(mask, kind_t::prelu)); + return *this; + } + + dnnl::post_ops extract_post_ops(const at::Tensor& dst) { + // this function is used to extract post ops params from the ops_params_ + // and put them into onednn post ops + for (size_t i = 0; i < ops_params_.size(); ++i) { + kind_t kind = ops_params_[i].kind_; + switch (kind) { + case kind_t::eltwise: { + dnnl::algorithm algo = ops_params_[i].algo_; + float alpha = ops_params_[i].alpha_; + float beta = ops_params_[i].beta_; + dnnl_post_ops_.append_eltwise(algo, alpha, beta); + break; + } + case kind_t::sum: { + float scale = ops_params_[i].scale_; + int64_t zero_point = ops_params_[i].zero_point_; + // TODO [Asymmetric]: + // Post-sum zp for gpu is not supported currently + dnnl_post_ops_.append_sum(scale, zero_point); + break; + } + case kind_t::binary: { + dnnl::algorithm algo = ops_params_[i].algo_; + auto expected_md = ops_params_[i].expected_meta_; + // In this case user may create src1 memory descriptor with + // format_tag::any or set a specific tag. However, in later case if + // tags mismatch with dst, it would result in suboptimal performance. + // So here we use format_tag::any to make sure the fast can be + // selected. 
+ // Thus we use expected_md (with format_any) here to create pd instead + // of original md + dnnl_post_ops_.append_binary(algo, expected_md); + break; + } + default: + break; + } + } + + return dnnl_post_ops_; + } + + bool with_sum() { + for (size_t i = 0; i < ops_params_.size(); ++i) { + if (ops_params_[i].kind_ == kind_t::sum) { + return true; + } + } + return false; + } + + bool with_binary() { + for (size_t i = 0; i < ops_params_.size(); ++i) { + if (ops_params_[i].kind_ == kind_t::binary) { + return true; + } + } + return false; + } + + void construct_post_binary( + dnnl::primitive_desc& pd, + std::unordered_map& args) { + // This function is used to construct binary memory desc in binary post ops. + // According to oneDNN doc, the binary tensor can be in shape of + // [1, 1, 1, 1], tensor broadcast + // [1, C, 1, 1], channel broadcast + // [dst.shape], no broadcast and eltwise-wise binary operations on dst + + auto& engine = GpuEngineManager::Instance().get_engine(); + for (size_t i = 0; i < ops_params_.size(); ++i) { + kind_t kind = ops_params_[i].kind_; + if (kind == kind_t::binary) { + dnnl::memory binary_m; + auto binary = ops_params_[i].binary_; + auto md = ops_params_[i].meta_; + // qeury expected_md to achieve peak performance + auto expected_md = pd.query_md( + dnnl::query::exec_arg_md, + DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1); + + binary_m = at::native::onednn::make_onednn_memory( + md, engine, binary.data_ptr()); + + args.insert( + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m}); + } + } + } + + float q_scale_ = 1.0; // the scale used to quantize the fused result from fp32 + // to int8, only works for int8 case + int64_t q_zero_point_ = 0; + std::vector ops_params_; // series of post ops + dnnl::post_ops dnnl_post_ops_; +}; + +static inline void construct_attr_for_unary( + const std::string_view& unary_post_op, + const torch::List>& unary_post_op_args, + const std::string_view& unary_post_op_algorithm, + at::native::onednn::Attr& attr) { + if (unary_post_op == "relu") { + attr = attr.append_post_eltwise( + /* eltwise_scale */ 1.f, + /* alpha */ 0.f, + /* beta */ 0.f, + attr.kind_with_relu); + } else if (unary_post_op == "leaky_relu") { + auto alpha = unary_post_op_args[0].value().to(); + attr = attr.append_post_eltwise(1.0, alpha, 0.f, attr.kind_with_relu); + } else if (unary_post_op == "tanh") { + attr = attr.append_post_eltwise(1.0f, 0.0f, 0.0f, attr.kind_with_tanh); + } else if (unary_post_op == "gelu") { + auto post_algorithm = unary_post_op_algorithm == "none" + ? 
attr.kind_with_gelu_erf + : attr.kind_with_gelu_tanh; + attr = attr.append_post_eltwise(1.0f, 0.0f, 0.0f, post_algorithm); + } else if (unary_post_op == "hardtanh") { + auto alpha = unary_post_op_args[0].value().to(); + auto beta = unary_post_op_args[1].value().to(); + attr = attr.append_post_eltwise(1.0, alpha, beta, attr.kind_with_clip); + } else if (unary_post_op == "hardswish") { + attr = attr.append_post_eltwise( + 1.0f, 1.f / 6.f, 1.f / 2.f, attr.kind_with_hardswish); + } else if (unary_post_op == "swish") { + attr = attr.append_post_eltwise(1.0f, 1.0f, 0.0f, attr.kind_with_swish); + } else { + TORCH_CHECK( + unary_post_op == "none", + "onednn qlinear: unspported unary post op", + unary_post_op); + } +} + +static inline void construct_attr_by_post_op( + const std::string_view& binary_post_op, + double binary_alpha, + double input1_scale, + int64_t input1_zero_point, + std::optional accum, + const std::string_view& unary_post_op, + const torch::List>& unary_post_op_args, + const std::string_view& unary_post_op_algorithm, + at::native::onednn::Attr& attr) { + bool is_none_post_op = + (binary_post_op == "none" && unary_post_op == "none"); // not post-ops + bool is_unary_post_op_only = + (binary_post_op == "none" && unary_post_op != "none"); // ex., conv + relu + bool is_valid_binary_combination = + (binary_post_op == "add" || binary_post_op == "sum") && + (unary_post_op == "none" || unary_post_op == "relu"); + TORCH_INTERNAL_ASSERT( + is_unary_post_op_only || is_none_post_op || is_valid_binary_combination, + "Please provide valid combination of unary post operators and binary post operators"); + + if (binary_post_op == "none") { + construct_attr_for_unary( + unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr); + } else if (binary_post_op == "sum") { + if (unary_post_op == "none") { + if (input1_zero_point != 0) + attr = attr.append_post_eltwise( + /*scale*/ 1.f, + /*alpha*/ 1.f, + -input1_zero_point * input1_scale, + attr.kind_with_linear); + attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0); + } else if (unary_post_op == "relu") { + if (input1_zero_point != 0) + attr = attr.append_post_eltwise( + /*scale*/ 1.f, + /*alpha*/ 1.f, + -input1_zero_point * input1_scale, + attr.kind_with_linear); + attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0); + attr = attr.append_post_eltwise( + /* scale */ 1.f, + /* alpha */ 0.f, + /* beta */ 0.f, + attr.kind_with_relu); + } + } else if (binary_post_op == "add") { + TORCH_CHECK(accum.has_value()); + attr = attr.append_post_binary(attr.kind_with_binary_add, accum.value()); + if (unary_post_op == "relu") { + attr = attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu); + } + } +} + +} // namespace at::native::onednn diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/DnnlExt.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/DnnlExt.h new file mode 100644 index 0000000000000000000000000000000000000000..7731144f1e212738c91203055813cf694095ca23 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/DnnlExt.h @@ -0,0 +1,594 @@ +#pragma once + +#include + +#include +#include +#include + +#include +#include + +namespace std { + +template <> +struct hash { + size_t operator()(dnnl::memory::dims const& vec) const { + size_t seed = vec.size(); + for (auto& i : vec) { + seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; + +} // namespace std + +using namespace dnnl; + +namespace 
at::native::onednn { + +class primitive_ext : public primitive { + static constexpr int max_args = 12; + + public: + primitive_ext(const primitive& base) : primitive(base) {} + primitive_ext(primitive&& base) : primitive(std::move(base)) {} + + /// Returns a memory descriptor. + /// + /// @note + /// There are also convenience methods + /// #dnnl::primitive_desc_base::src_desc(), + /// #dnnl::primitive_desc_base::dst_desc(), and others. + /// + /// @param what The kind of parameter to query; can be + /// #dnnl::query::src_md, #dnnl::query::dst_md, etc. + /// @param idx Index of the parameter. For example, convolution bias can + /// be queried with what = #dnnl::query::weights_md and idx = 1. + /// @returns The requested memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// parameter of the specified kind or index. + const_dnnl_memory_desc_t query_md(query what, int idx = 0) const { + std::vector valid_q{ + query::src_md, + query::diff_src_md, + query::weights_md, + query::diff_weights_md, + query::dst_md, + query::diff_dst_md, + query::workspace_md, + query::scratchpad_md, + query::exec_arg_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), [=](query q) { + return what == q; + })) + DNNL_THROW_ERROR( + dnnl_invalid_arguments, "memory descriptor query is invalid"); + + const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md( + this->get_primitive_desc(), dnnl::convert_to_c(what), idx); + + return cdesc ? cdesc : nullptr; + } + + /// Returns a source memory descriptor. + /// @param idx Source index. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter with index @p idx. + const_dnnl_memory_desc_t src_desc(int idx) const { + return query_md(query::src_md, idx); + } + + /// Returns a destination memory descriptor. + /// @param idx Destination index. + /// @returns Destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter with index @p idx. + const_dnnl_memory_desc_t dst_desc(int idx) const { + return query_md(query::dst_md, idx); + } + + /// Returns a weights memory descriptor. + /// @param idx Weights index. + /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter with index @p idx. + const_dnnl_memory_desc_t weights_desc(int idx) const { + return query_md(query::weights_md, idx); + } + + /// Returns a diff source memory descriptor. + /// @param idx Diff source index. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source parameter with index @p idx. + const_dnnl_memory_desc_t diff_src_desc(int idx) const { + return query_md(query::diff_src_md, idx); + } + + /// Returns a diff destination memory descriptor. + /// @param idx Diff destination index. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter with index @p idx. + const_dnnl_memory_desc_t diff_dst_desc(int idx) const { + return query_md(query::diff_dst_md, idx); + } + + /// Returns a diff weights memory descriptor. + /// @param idx Diff weights index. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter with index @p idx. 
+ const_dnnl_memory_desc_t diff_weights_desc(int idx) const { + return query_md(query::diff_weights_md, idx); + } + + const_dnnl_memory_desc_t exec_arg_desc(int idx) const { + return query_md(query::exec_arg_md, idx); + } + + // Separate versions without the index argument for documentation + // purposes. + + /// Returns a source memory descriptor. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter. + const_dnnl_memory_desc_t src_desc() const { + return src_desc(0); + } + + /// Returns a destination memory descriptor. + /// @returns Destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter. + const_dnnl_memory_desc_t dst_desc() const { + return dst_desc(0); + } + + /// Returns a weights memory descriptor. + /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter. + const_dnnl_memory_desc_t weights_desc() const { + return weights_desc(0); + } + + /// Returns a diff source memory descriptor. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source memory with. + const_dnnl_memory_desc_t diff_src_desc() const { + return diff_src_desc(0); + } + + /// Returns a diff destination memory descriptor. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter. + const_dnnl_memory_desc_t diff_dst_desc() const { + return diff_dst_desc(0); + } + + /// Returns a diff weights memory descriptor. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter. + const_dnnl_memory_desc_t diff_weights_desc() const { + return diff_weights_desc(0); + } + + /// Returns the workspace memory descriptor. + /// @returns Workspace memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// workspace parameter. + const_dnnl_memory_desc_t workspace_desc() const { + return query_md(query::workspace_md, 0); + } + + /// Returns the scratchpad memory descriptor. + /// @returns scratchpad memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// scratchpad parameter. 
+ /// @sa @ref dev_guide_attributes_scratchpad + const_dnnl_memory_desc_t scratchpad_desc() const { + return query_md(query::scratchpad_md, 0); + } + + inline memory make_memory( + const_dnnl_memory_desc_t md_t, + const engine& aengine, + void* handle = DNNL_MEMORY_ALLOCATE) const { + sycl_interop::memory_kind kind = dnnl::sycl_interop::memory_kind::usm; + dnnl_memory_t c_memory; + error::wrap_c_api( + dnnl_sycl_interop_memory_create( + &c_memory, md_t, aengine.get(), convert_to_c(kind), handle), + "could not create a memory"); + return memory(c_memory); + } + + memory make_src(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(src_desc(), aengine, handle); + } + + memory make_weight(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(weights_desc(), aengine, handle); + } + + memory make_bias(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(weights_desc(1), aengine, handle); + } + + memory make_dst(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(dst_desc(), aengine, handle); + } + + memory make_scratchpad( + const engine& aengine, + void* handle = DNNL_MEMORY_ALLOCATE) const { + return make_memory(scratchpad_desc(), aengine, handle); + } + + size_t get_scratchpad_size() const { + return dnnl_memory_desc_get_size(scratchpad_desc()); + } + + memory make_args(int arg_class, const engine& aengine, void* handle) const { + switch (arg_class) { + case DNNL_ARG_SRC: + return make_src(aengine, handle); + case DNNL_ARG_WEIGHTS: + return make_weight(aengine, handle); + case DNNL_ARG_SCRATCHPAD: + return make_scratchpad(aengine, handle); + case DNNL_ARG_DST: + return make_dst(aengine, handle); + case DNNL_ARG_BIAS: + return make_bias(aengine, handle); + default: + TORCH_INTERNAL_ASSERT( + false, "unsupported argument class for primitive_ext"); + } + } + + template + void set_attribute(int slot, int arg_class, void* handle, M constructor) { + if (mem_arg_cache[slot]) + mem_arg_cache[slot].set_data_handle(handle); + else { + mem_arg_cache[slot] = constructor(); + c_args[slot].arg = arg_class; + c_args[slot].memory = mem_arg_cache[slot].get(); + } + } + + sycl::event execute( + const stream& astream, + const engine& aengine, + std::vector>&& handles, + int slot_off = 2) { + auto off = slot_off; + for (const auto& p : handles) { + auto& m_arg = mem_arg_cache[off]; + if (m_arg) + m_arg.set_data_handle(p.second); + else { + m_arg = make_args(p.first, aengine, p.second); + c_args[off].arg = p.first; + c_args[off].memory = m_arg.get(); + } + ++off; + } + + sycl::event return_event; + std::vector deps{}; + error::wrap_c_api( + dnnl_sycl_interop_primitive_execute( + this->get(), astream.get(), off, c_args, &deps, &return_event), + "could not execute a primitive"); + return return_event; + } + + private: + memory mem_arg_cache[max_args]; + dnnl_exec_arg_t c_args[max_args]; +}; + +// Specifies the combined data types of input and weight tensors. +// For example, f32 means both input and weight are FP32, +// bf16_int4 means input is BF16 and weight is INT4. +enum class joint_dtypes_t { f32 = 0, f16, bf16, int8, f16_int4, bf16_int4 }; + +// Specifies the transposition state of input and weight tensors. +// Convention: first letter = input, second letter = weight. +// 'n' = not transposed, 't' = transposed. +// For example, 'nt' means input is not transposed, weight is transposed. 
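Editor's note: the three selector enums declared just below (joint_dtypes_t, trans_type_t, bias_type_t) only act as cache-key inputs; they are consumed by matmul_primitive_create_and_cache() further down in this header. A minimal, hypothetical call-site sketch follows; the variables m, n, k, lda, ldb, ldc and device_id, the 128-element quantization group size, and the attribute-lambda body are illustrative assumptions, not taken from this diff, and the call is assumed to sit inside (or use) namespace at::native::onednn:

    // Request (or build and cache) an f16 x int4 matmul primitive with a
    // transposed weight ('nt') and no bias. The lambda customizes the
    // dnnl::primitive_attr before the primitive is created; a real int4 path
    // would also configure weight scales / zero-points on pattr here.
    auto set_attr = [](dnnl::primitive_attr& pattr) {
      pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    };
    auto& prim = matmul_primitive_create_and_cache(
        joint_dtypes_t::f16_int4, trans_type_t::nt, bias_type_t::none,
        m, n, k, lda, ldb, ldc, device_id, set_attr,
        /*scale_group_size=*/128, /*zp_group_size=*/128);
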
+enum class trans_type_t { nn = 0, nt, tn, tt }; + +// Specifies the type and placement of bias in the computation. +// 'none' = no bias, +// 'scalar' = a single scalar bias applied to all elements, +// 'm' = per-row bias (typically matched to input rows), +// 'n' = per-column bias (typically matched to output channels), +// 'mn' = full bias matrix matching the output dimensions. +enum class bias_type_t { none = 0, scalar, m, n, mn }; + +template +T concat(const T& t1, at::ScalarType d) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back((int64_t)d); + + return t; +} + +template +T concat(const T& t1, bool b) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back(b); + + return t; +} + +template +T concat(const T& t1, int b) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back(b); + + return t; +} + +template +T concat(const T& t1, const T& t2) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.insert(t.end(), t2.begin(), t2.end()); + + return t; +} + +template +T1 concat(const T1& t1, const T2& t2, const Ts&... ts) { + return concat(concat(t1, t2), ts...); +} + +template +struct onednn_types_mapper; + +template <> +struct onednn_types_mapper { + static inline std::tuple + get() { + return std::make_tuple( + dnnl::memory::data_type::f16, dnnl::memory::data_type::u4); + } +}; + +template <> +struct onednn_types_mapper { + static inline std::tuple + get() { + return std::make_tuple( + dnnl::memory::data_type::bf16, dnnl::memory::data_type::u4); + } +}; + +// TODO: bias types maybe not right +static inline dnnl::memory::dims get_bias_type( + bias_type_t b_dims, + const int m, + const int n) { + switch (b_dims) { + case bias_type_t::none: + return {0}; + case bias_type_t::scalar: + return {1, 1}; + case bias_type_t::m: + return {m, 1}; + case bias_type_t::n: + return {1, n}; + case bias_type_t::mn: + return {m, n}; + default: + TORCH_INTERNAL_ASSERT(false, "unsupported bias type ..."); + } +} + +// TODO: use template specialization on struct +template +inline void get_strides( + memory::dims& src_strides, + memory::dims& wei_strides, + memory::dims& dst_strides, + const int64_t lda, + const int64_t ldb, + const int64_t ldc) {} + +template <> +inline void get_strides( + memory::dims& src_strides, + memory::dims& wei_strides, + memory::dims& dst_strides, + const int64_t lda, + const int64_t ldb, + const int64_t ldc) { + src_strides = {lda, 1}; + wei_strides = {1, ldb}; + dst_strides = {ldc, 1}; +} + +using primitive_cache = + at::native::onednn::lru_cache; + +template +struct matmul_primitive_cache_t { + static inline primitive_ext& get( + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, + const int64_t ldc, + const bias_type_t + b_dims, // for shapeless bias, not put it into template parameter + const int device_id, + F f_attr, + const int64_t scale_group_size, + const int64_t zp_group_size) { + auto& cached = get_cache(device_id); + memory::dims src_strides, wei_strides, dst_strides; + get_strides(src_strides, wei_strides, dst_strides, lda, ldb, ldc); + auto pri_key = at::native::onednn::concat( + src_strides, + wei_strides, + m, + n, + k, + int(b_dims), + int(scale_group_size), + int(zp_group_size)); + auto iter = cached.find(pri_key); + if (iter == cached.end()) { + auto [src_dt, wei_dt] = onednn_types_mapper::get(); + auto bias_dims = get_bias_type(b_dims, m, n); + + auto src_md = memory::desc({m, k}, src_dt, src_strides); + auto wei_md = memory::desc({k, n}, wei_dt, wei_strides); + auto dst_md = memory::desc({m, 
n}, src_dt, dst_strides); + auto bias_format = b_dims == bias_type_t::none + ? dnnl::memory::format_tag::undef + : dnnl::memory::format_tag::ab; + auto bias_md = + memory::desc(bias_dims, src_dt, bias_format); // {m, n} or {1, n} + + primitive_attr pattr; + f_attr(pattr); + + dnnl::matmul::primitive_desc matmul_pd; + auto aengine = + at::native::onednn::GpuEngineManager::Instance().get_engine( + device_id); + if (b_dims == bias_type_t::none) { + matmul_pd = dnnl::matmul::primitive_desc( + aengine, src_md, wei_md, dst_md, pattr); + } else { + matmul_pd = dnnl::matmul::primitive_desc( + aengine, src_md, wei_md, bias_md, dst_md, pattr); + } + + return cached.insert({pri_key, primitive_ext(dnnl::matmul(matmul_pd))}) + .first->second; + } else { + return iter->second; + } + } + + private: + static constexpr int max_cache_capacity = 512; + // if default constructor of primitive cache could read the environment + // variable then it'll save a lot of trouble + static inline thread_local std::array mappings; + + // this won't be needed if primitive_cache have good default constructor + static inline primitive_cache& get_cache(const int device_id) { + auto& mapping = mappings[device_id]; + if (mapping.max_size() == 0) { + mapping.resize(max_cache_capacity); + } + return mapping; + } +}; + +template +static inline primitive_ext& matmul_primitive_create_and_cache( + const trans_type_t Tt, + const bias_type_t b_dims, + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, + const int64_t ldc, + const int device_id, + F attr, + const int64_t scale_group_size, + const int64_t zp_group_size) { + switch (Tt) { + case trans_type_t::nt: + return matmul_primitive_cache_t::get( + m, + n, + k, + lda, + ldb, + ldc, + b_dims, + device_id, + attr, + scale_group_size, + zp_group_size); + default: + TORCH_INTERNAL_ASSERT(false, "unsupported trans type ..."); + } +} + +template +static inline primitive_ext& matmul_primitive_create_and_cache( + const joint_dtypes_t Ts, + const trans_type_t Tt, + const bias_type_t b_dims, + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, // is weight ldb necessary? 
+ const int64_t ldc, + const int device_id, + F attr, + const int64_t scale_group_size = 0, + const int64_t zp_group_size = 0) { + switch (Ts) { + case joint_dtypes_t::f16_int4: + return matmul_primitive_create_and_cache( + Tt, + b_dims, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); + case joint_dtypes_t::bf16_int4: + return matmul_primitive_create_and_cache( + Tt, + b_dims, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); + default: + TORCH_INTERNAL_ASSERT(false, "Only support int4 ..."); + } +} + +} // namespace at::native::onednn diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/LRUCache.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/LRUCache.h new file mode 100644 index 0000000000000000000000000000000000000000..1ea76d0b069cc17d67c734d21f1938421c1ea90e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/LRUCache.h @@ -0,0 +1,110 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native::onednn { + +template < + class key_t, + class value_t, + template class map_t = std::unordered_map> +class lru_cache { + public: + using value_type = std::pair; + using list_type = std::list; + using list_iter = typename list_type::iterator; + using map_type = map_t; + using const_list_iter = typename list_type::const_iterator; + using size_type = typename list_type::size_type; + + explicit lru_cache(size_type capacity) : capacity_(capacity) {} + lru_cache() : capacity_(0) {} + + [[nodiscard]] size_type size() const noexcept { + return map_.size(); + } + [[nodiscard]] size_type max_size() const noexcept { + return capacity_; + } + [[nodiscard]] bool empty() const noexcept { + return vlist_.empty(); + } + + void resize(size_type new_capacity) { + capacity_ = new_capacity; + trim(); + } + + list_iter begin() noexcept { + return vlist_.begin(); + } + const_list_iter begin() const noexcept { + return vlist_.begin(); + } + list_iter end() noexcept { + return vlist_.end(); + } + const_list_iter end() const noexcept { + return vlist_.end(); + } + + void clear() noexcept { + map_.clear(); + vlist_.clear(); + } + + void swap(lru_cache& other) noexcept { + using std::swap; + swap(vlist_, other.vlist_); + swap(map_, other.map_); + swap(capacity_, other.capacity_); + } + + list_iter find(const key_t& key) { + auto it = map_.find(key); + if (it == map_.end()) + return end(); + vlist_.splice(vlist_.begin(), vlist_, it->second); + return it->second; + } + + std::pair insert(const value_type& value) { + auto it = map_.find(value.first); + if (it != map_.end()) { + // Move existing to front + vlist_.splice(vlist_.begin(), vlist_, it->second); + return {it->second, false}; + } + + // Insert new at front + vlist_.emplace_front(value); + map_[value.first] = vlist_.begin(); + + trim(); + + return {vlist_.begin(), true}; + } + + list_iter erase(list_iter pos) { + map_.erase(pos->first); + return vlist_.erase(pos); + } + + private: + void trim() { + while (map_.size() > capacity_) { + auto last = std::prev(vlist_.end()); + map_.erase(last->first); + vlist_.pop_back(); + } + } + + list_type vlist_; + map_type map_; + size_type capacity_; +}; + +} // namespace at::native::onednn diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Utils.h new file mode 100644 index 
0000000000000000000000000000000000000000..f11915125963b1353bf337d9bb23eee788b3230e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Utils.h @@ -0,0 +1,134 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#define ONEDNN_SUPPORT_DETERMINISTIC \ + (DNNL_VERSION_MAJOR >= 3 && DNNL_VERSION_MINOR >= 4) + +namespace at::native::onednn { + +dnnl::memory::format_tag get_dnnl_default_format( + int ndims, + bool is_channels_last = false, + bool allow_undef = false); + +dnnl::memory::data_type get_onednn_dtype( + const at::Tensor& tensor, + bool allow_undef = false); + +dnnl::memory::data_type get_onednn_dtype_include_double( + const at::Tensor& tensor, + bool allow_undef = false); + +bool is_supported_onednn_dtype(const at::Tensor& tensor); + +dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor); + +dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor); +dnnl::memory::desc get_onednn_md(const at::Tensor& tensor); + +bool onednn_strides_check(const at::Tensor& src); +bool is_broadcast(const at::Tensor& t); +void undo_broadcast_on_batch(at::Tensor& m1, at::Tensor& m2); +void undo_broadcast(at::Tensor& tensor); + +bool is_onednn_matmul_strides(const at::Tensor& tensor); + +bool is_broadcast_from_other_to_self( + const at::Tensor& self, + const at::Tensor& other); + +at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim); + +void apply_tf32_if_allowed(dnnl::primitive_attr& primitive_attr); + +bool binary_valid( + const at::Tensor& self, + const at::Tensor& other, + bool is_fusion = false); + +bool use_channels_last_for_conv( + const at::Tensor& src, + const at::Tensor& weight); + +dnnl::memory::format_tag conv_src_fmt( + const int64_t ndim, + const bool is_channels_last = false); + +dnnl::memory::dims compatible_weight_dims( + const int64_t ndim, + const int64_t groups, + const int64_t oc, + const int64_t ic, + const IntArrayRef wsizes); + +dnnl::memory::format_tag conv_weight_fmt( + const int64_t ndim, + const bool grouped = false, + const bool is_channels_last = false); + +template +dnnl::memory::dims compatible_dilation(Vec&& dilation) { + dnnl::memory::dims ret = dilation.vec(); + for (auto it = ret.begin(); it != ret.end(); it++) { + *it -= 1; + } + return ret; +} + +template +dnnl::memory dnnl_memory_from_host_scalar( + T host_value, + Tensor& holder, + dnnl::engine& engine) { + auto options = at::TensorOptions() + .dtype(c10::CppTypeToScalarType::value) + .device(kXPU); + holder = at::empty({1}, options).fill_(host_value); + dnnl::memory::desc md = get_onednn_md(holder); + dnnl::memory mem = make_onednn_memory(md, engine, holder.data_ptr()); + return mem; +} + +struct PartitionCache { + std::unordered_map, dnnl::graph::partition> partition_map_{}; + + // The first 8 bits are reserved + // bit 0: is int8 + // bit 1: is uint8 + // bit 2: fp16(0) / bf16(1) + // bit 3: is fp32 + // bit 4: is sdp pattern + // bit 5-7: N/A + // The rest of the bits depend upon the arguments provided + // However, down the line, we might have different bitsets for different + // patterns + dnnl::graph::partition& insert_partition_cache( + std::bitset<32>& patternID, + dnnl::graph::partition& p) { + partition_map_[patternID] = std::move(p); + return partition_map_[patternID]; + } + std::optional> find_partition( + std::bitset<32>& patternID) { + auto iter = partition_map_.find(patternID); + if (iter != partition_map_.end()) { + return iter->second; + } + return 
std::nullopt; + } +}; + +} // namespace at::native::onednn diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNN.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNN.h new file mode 100644 index 0000000000000000000000000000000000000000..ca31bb57760255a36524a91007ea5450930c0f29 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -0,0 +1,182 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native::onednn { + +TORCH_API sycl::event matmul( + at::Tensor& result, + const at::Tensor& mat1, + const at::Tensor& mat2, + const at::Tensor& b_raw, + bool m2_trans, + Attr attr, + const std::vector& deps = {}); + +TORCH_API sycl::event convolution( + at::Tensor& dst, + const at::Tensor& src, + const at::Tensor& weight, + const at::Tensor& bia, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + Attr& attr, + const std::vector& deps = {}); + +TORCH_API sycl::event convolution_backward_weights( + at::Tensor& diff_weight, + at::Tensor& diff_bia, + const at::Tensor& diff_dst, + const at::Tensor& src, + IntArrayRef diff_weight_aten_size, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + const std::vector& deps = {}); + +TORCH_API sycl::event convolution_backward_data( + at::Tensor& diff_src, + const at::Tensor& diff_dst, + const at::Tensor& weight, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool bias_defined, + const std::vector& deps = {}); + +TORCH_API sycl::event deconvolution( + at::Tensor& dst, + const at::Tensor& src, + const at::Tensor& weight, + const at::Tensor& bia, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dst_padding, + IntArrayRef dilation, + int64_t groups, + Attr& attr, + const std::vector& deps = {}); + +TORCH_API sycl::event deconvolution_backward_data( + at::Tensor& diff_src, + const at::Tensor& diff_dst, + const at::Tensor& weight, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + bool bias_defined, + const std::vector& deps = {}); + +TORCH_API sycl::event deconvolution_backward_weights( + at::Tensor& diff_weight, + at::Tensor& diff_bia, + const at::Tensor& diff_dst, + const at::Tensor& src, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + const std::vector& deps = {}); + +TORCH_API void woq_matmul_int4( + at::Tensor& result, // dst, [M, N] + const at::Tensor& mat1_, // src, [M, K] + const at::Tensor& mat2_, // quantized weight, [K/8, N] + const at::Tensor& scale, // [K/group_size, N] + const at::Tensor& zp, // [k/group_size, N] + int64_t group_size, + bool pri_cache = true); + +dnnl::memory::dims conv_dst_size( + int64_t ndim, + IntArrayRef src_tz, + IntArrayRef wgh_tz, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation); + +dnnl::memory::dims deconv_dst_size( + IntArrayRef src_size, + IntArrayRef wgh_size, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + IntArrayRef dst_padding, + int64_t groups); + +at::Tensor quantized_convolution( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, 
+ std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + bool transposed, + int64_t groups, + at::Tensor output, + double inv_output_scale, + int64_t output_zero_point, + std::optional accum, + double accum_scale, + int64_t accum_zero_point, + std::optional output_dtype, + std::optional binary_attr, + std::optional binary_alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +void quantized_matmul( + at::Tensor mat1, // act + double input_scale, + int64_t input_zero_point, + at::Tensor mat2, // weight + at::Tensor& weight_scales, + at::Tensor& weight_zero_points, + at::Tensor& b_raw, + at::Tensor result, // output + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::optional other, // extra input for binary-post-op + double other_scale, + int64_t other_zero_point, + const std::string_view& binary_post_op, + double binary_alpha, + const std::string_view& unary_post_op, + torch::List>& unary_post_op_args, + std::string_view unary_post_op_algorithm, + bool m2_trnas); + +void gpu_float_sdpa( + int batch_size, + int seq_len_q, + int seq_len_kv, + int num_head_q, + int num_head_kv, + int head_dim_qk, + int head_dim_v, + const Tensor& query, + const Tensor& key, + const Tensor& value, + std::optional attn_mask, + bool is_causal, + float softmax_scale, + const Tensor& output); +} // namespace at::native::onednn diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNNContext.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNNContext.h new file mode 100644 index 0000000000000000000000000000000000000000..b4f091d4dbb99fd4cc2541091a3e15905d0a1dd5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNNContext.h @@ -0,0 +1,90 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native::onednn { + +TORCH_XPU_API dnnl::memory make_onednn_memory( + dnnl::memory::desc md, + dnnl::engine& engine, + void* ptr); + +// Keep non-static and non-inline +bool set_onednn_verbose(int level); + +// GpuEngineManager singleton +struct TORCH_XPU_API GpuEngineManager { + static GpuEngineManager& Instance(); // Singleton + + dnnl::engine& get_engine( + DeviceIndex device_index = c10::xpu::current_device()) { + c10::xpu::check_device_index(device_index); + return *engine_pool[device_index]; + } + + dnnl::engine& get_engine(const Device& device) { + TORCH_INTERNAL_ASSERT(device.type() == kXPU); + return get_engine(device.index()); + } + + GpuEngineManager(GpuEngineManager const&) = delete; + GpuEngineManager& operator=(GpuEngineManager const&) = delete; + GpuEngineManager(GpuEngineManager&&) = default; + GpuEngineManager& operator=(GpuEngineManager&&) = default; + + protected: + GpuEngineManager(); + ~GpuEngineManager() = default; + + private: + std::vector> engine_pool; +}; + +// GpuStreamManager singleton +struct TORCH_XPU_API GpuStreamManager { + static GpuStreamManager& Instance(); // Singleton + + dnnl::stream& get_stream( + DeviceIndex device_index = c10::xpu::current_device()) { + auto stream = c10::xpu::getCurrentXPUStream(device_index); + auto priority = stream.priority(); + if (stream_pool[device_index][priority].find(stream) == + stream_pool[device_index][priority].end()) { + stream_pool[device_index][priority][stream] = + std::make_shared(dnnl::sycl_interop::make_stream( + GpuEngineManager::Instance().get_engine(device_index), + 
stream.queue())); + } + return *stream_pool[device_index][priority][stream]; + } + + GpuStreamManager(GpuStreamManager const&) = delete; + GpuStreamManager& operator=(GpuStreamManager const&) = delete; + GpuStreamManager(GpuStreamManager&&) = default; + GpuStreamManager& operator=(GpuStreamManager&&) = default; + + protected: + GpuStreamManager() { + c10::DeviceIndex device_count = c10::xpu::device_count_ensure_non_zero(); + stream_pool.resize(device_count); + } + ~GpuStreamManager() = default; + + private: + using stream_hash_map = + ska::flat_hash_map>; + std::vector< + std::array> + stream_pool; +}; + +} // namespace at::native::onednn diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/Copy.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..d430db53fa38a8573bfe1f93ae582714ff01034e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/Copy.h @@ -0,0 +1,14 @@ +// Copyright © 2022 Apple Inc. + +#pragma once +#include + +namespace at::native::mps { + +at::Tensor& mps_copy_( + at::Tensor& dst, + const at::Tensor& src, + bool non_blocking); +void copy_blit_mps(void* dst, const void* src, size_t size); + +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..2543f878843d6209bc9bfcc8c5c3dbef30d83d27 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#if !defined(__MAC_15_0) && (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0)) + +@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel +- (MPSNDArray* __nullable)reshapeWithCommandBuffer:(__nullable id)cmdBuf + sourceArray:(MPSNDArray* __nonnull)sourceArray + shape:(MPSShape* __nonnull)shape + destinationArray:(MPSNDArray* __nullable)destinationArray; +@end + +@interface MPSNDArrayDescriptor () +@property(readwrite, nonatomic) BOOL preferPackedRows; +@end + +@interface MPSNDArray () +- (nonnull instancetype)initWithBuffer:(id _Nonnull)buffer + offset:(NSUInteger)offset + descriptor:(MPSNDArrayDescriptor* _Nonnull)descriptor; +- (MPSNDArray* __nullable)arrayViewWithShape:(MPSShape* _Nullable)shape strides:(MPSShape* _Nonnull)strides; +@end + +typedef NS_ENUM(NSInteger, MTLMathMode) { + MTLMathModeSafe = 0, + MTLMathModeRelaxed = 1, + MTLMathModeFast = 2, +}; + +typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions) { + MTLMathFloatingPointFunctionsFast = 0, + MTLMathFloatingPointFunctionsPrecise = 1, +}; + +@interface MTLCompileOptions () +@property(readwrite, nonatomic) MTLMathMode mathMode; +@property(readwrite, nonatomic) MTLMathFloatingPointFunctions mathFloatingPointFunctions; +@end + +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..3067507378e7761bc8778cdb1adaf277da5fc5b6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h @@ -0,0 +1,48 @@ +#pragma once + +#include + +#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0)) + +typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) 
{ + MPSGraphFFTScalingModeNone = 0L, + MPSGraphFFTScalingModeSize = 1L, + MPSGraphFFTScalingModeUnitary = 2L, +}; + +@interface FakeMPSGraphFFTDescriptor : NSObject +@property(readwrite, nonatomic) BOOL inverse; +@property(readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode; +@property(readwrite, nonatomic) BOOL roundToOddHermitean; ++ (nullable instancetype)descriptor; +@end + +@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor; + +@interface MPSGraph (SonomaOps) +- (MPSGraphTensor* _Nonnull)conjugateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)realPartOfTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)fastFourierTransformWithTensor:(MPSGraphTensor* _Nonnull)tensor + axes:(NSArray* _Nonnull)axes + descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)realToHermiteanFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor + axes:(NSArray* _Nonnull)axes + descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)HermiteanToRealFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor + axes:(NSArray* _Nonnull)axes + descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; +@end + +// define BFloat16 enums for MacOS13 +#define MPSDataTypeBFloat16 ((MPSDataType)(MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16)) + +// define Metal version +#define MTLLanguageVersion3_1 ((MTLLanguageVersion)((3 << 16) + 1)) +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..18da72f73b7ca0f6568baa07055785ec200cb89c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h @@ -0,0 +1,196 @@ +#pragma once +#include + +// TODO: Remove me when moved to MacOS 13 +#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2)) + +@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject + +@property(readwrite, nonatomic) NSUInteger strideInX; +@property(readwrite, nonatomic) NSUInteger strideInY; +@property(readwrite, nonatomic) NSUInteger strideInZ; +@property(readwrite, nonatomic) NSUInteger dilationRateInX; +@property(readwrite, nonatomic) NSUInteger dilationRateInY; +@property(readwrite, nonatomic) NSUInteger dilationRateInZ; + +@property(readwrite, nonatomic) NSUInteger paddingLeft; +@property(readwrite, nonatomic) NSUInteger paddingRight; +@property(readwrite, nonatomic) NSUInteger paddingTop; +@property(readwrite, nonatomic) NSUInteger paddingBottom; +@property(readwrite, nonatomic) NSUInteger paddingFront; +@property(readwrite, nonatomic) NSUInteger paddingBack; + +@property(readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle; +@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout; +@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout; + +@property(readwrite, nonatomic) NSUInteger groups; + +@end + +@compatibility_alias MPSGraphConvolution3DOpDescriptor FakeMPSGraphConvolution3DOpDescriptor; + +#endif + +@interface MPSGraph (VenturaOps) + +#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0)) + +typedef NS_ENUM(NSUInteger, 
MPSGraphResizeNearestRoundingMode) { + MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L, + MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L, + MPSGraphResizeNearestRoundingModeCeil = 2L, + MPSGraphResizeNearestRoundingModeFloor = 3L, + MPSGraphResizeNearestRoundingModeRoundToEven = 4L, + MPSGraphResizeNearestRoundingModeRoundToOdd = 5L, +}; + +// Define complex enums for MacOS 12 +#define MPSDataTypeComplexBit 0x01000000 +#define MPSDataTypeComplexFloat32 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64)) +#define MPSDataTypeComplexFloat16 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32)) +#endif + +- (MPSGraphTensor* _Nonnull)convolution3DWithSourceTensor:(MPSGraphTensor* _Nonnull)source + weightsTensor:(MPSGraphTensor* _Nonnull)weights + descriptor:(MPSGraphConvolution3DOpDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull) + convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient + weightsTensor:(MPSGraphTensor* _Nonnull)weights + outputShape:(MPSShape* _Nonnull)outputShape + forwardConvolutionDescriptor: + (MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull) + convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient + sourceTensor:(MPSGraphTensor* _Nonnull)source + outputShape:(MPSShape* _Nonnull)outputShape + forwardConvolutionDescriptor: + (MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)inverseOfTensor:(MPSGraphTensor* _Nonnull)inputTensor name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + 
name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source + coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates + layout:(MPSGraphTensorNamedDataLayout)layout + normalizeCoordinates:(BOOL)normalizeCoordinates + relativeCoordinates:(BOOL)relativeCoordinates + alignCorners:(BOOL)alignCorners + paddingMode:(MPSGraphPaddingMode)paddingMode + samplingMode:(MPSGraphResizeMode)samplingMode + constantValue:(double)constantValue + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source + coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates + layout:(MPSGraphTensorNamedDataLayout)layout + normalizeCoordinates:(BOOL)normalizeCoordinates + relativeCoordinates:(BOOL)relativeCoordinates + alignCorners:(BOOL)alignCorners + paddingMode:(MPSGraphPaddingMode)paddingMode + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + constantValue:(double)constantValue + name:(NSString* _Nullable)name; +- (MPSGraphTensor* _Nonnull)truncateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name; + +@end diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MetalShaderLibrary.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MetalShaderLibrary.h new file 
mode 100644 index 0000000000000000000000000000000000000000..81dbc15694dc22075f86ccca71e63a208cd19dd5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/MetalShaderLibrary.h @@ -0,0 +1,178 @@ +#pragma once +#ifdef __OBJC__ +#include +typedef id MTLLibrary_t; +typedef id MTLFunction_t; +typedef id MTLComputePipelineState_t; +typedef id MTLComputeCommandEncoder_t; +#else +typedef void MTLCompileOptions; +typedef void* MTLLibrary_t; +typedef void* MTLFunction_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLComputeCommandEncoder_t; +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declaration of TensorBase and TensorIteratorBase +namespace at { +class TensorBase; +struct TensorIteratorBase; +} // namespace at + +namespace at::native::mps { + +namespace detail { +template +class has_size_type { + template + static constexpr std::true_type check(typename U::size_type*); + template + static constexpr std::false_type check(...); + + public: + static constexpr bool value = decltype(check(nullptr))::value; +}; + +template +constexpr bool has_size_type_v = has_size_type::value; + +} // namespace detail + +// Returns `gpuAddress` of respective `id` plus storage offset +void* get_tensor_gpu_address(const at::TensorBase&); + +class MetalKernelFunction { + public: + MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_); + ~MetalKernelFunction(); + MetalKernelFunction(MetalKernelFunction&) = delete; + // Shader properties + uint64_t getMaxThreadsPerThreadgroup() const; + uint64_t getThreadExecutionWidth() const; + uint64_t getStaticThreadGroupMemoryLength() const; + void runCommandBlock(std::function f); + // Methods below should be called from runCommandBlock function + void startEncoding(); + void setArg(unsigned idx, const at::TensorBase& t); + void setArg(unsigned idx, const void* ptr, uint64_t size); + template < + typename T, + typename = std::enable_if_t< + std::is_integral_v || std::is_same_v || + (std::is_class_v && std::is_trivially_copyable_v && + !detail::has_size_type_v)>> + inline void setArg(unsigned idx, const T val) { + setArg(idx, &val, sizeof(T)); + } + + template < + typename Container, + typename = std::enable_if_t>> + inline void setArg(unsigned idx, const Container& values) { + setArg( + idx, + values.data(), + values.size() * sizeof(typename Container::value_type)); + } + void dispatch( + uint64_t length, + std::optional groupSize = std::nullopt); + void dispatch( + c10::ArrayRef length, + c10::OptionalArrayRef groupSize = std::nullopt); + + private: + MTLComputePipelineState_t cps; + MTLFunction_t func; + MTLComputeCommandEncoder_t encoder = nullptr; +}; + +class MetalShaderLibrary { + public: + MetalShaderLibrary(std::string src) + : shaderSource(std::move(src)), nparams(0), compile_options(nullptr) {} + MetalShaderLibrary(std::string src, unsigned nparams_) + : shaderSource(std::move(src)), + nparams(nparams_), + compile_options(nullptr) {} + MetalShaderLibrary( + std::string src, + unsigned nparams_, + MTLCompileOptions* compile_options_) + : shaderSource(std::move(src)), + nparams(nparams_), + compile_options(compile_options_) {} + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + virtual ~MetalShaderLibrary(); + std::vector getFunctionNames(); + std::shared_ptr getKernelFunction( + const std::string& name); + inline MTLComputePipelineState_t getPipelineStateForFunc( + const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).first; + } + 
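  // [Editor's note] Illustrative usage sketch, not part of the upstream header.
  // The Metal kernel source and the tensor `out_tensor` are hypothetical; only
  // APIs declared in this header (DynamicMetalShaderLibrary, getKernelFunction,
  // runCommandBlock, startEncoding, setArg, dispatch) are used:
  //
  //   DynamicMetalShaderLibrary lib(R"MTL(
  //     kernel void fill_ones(device float* out [[buffer(0)]],
  //                           uint idx [[thread_position_in_grid]]) { out[idx] = 1.0; }
  //   )MTL");
  //   auto fn = lib.getKernelFunction("fill_ones");
  //   fn->runCommandBlock([&] {
  //     fn->startEncoding();
  //     fn->setArg(0, out_tensor);        // at::TensorBase overload
  //     fn->dispatch(out_tensor.numel()); // 1-D dispatch, default threadgroup size
  //   });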
MTLComputePipelineState_t getPipelineStateForFunc( + const std::string& fname, + const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).first; + } + inline MTLFunction_t getMTLFunction(const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).second; + } + MTLFunction_t getMTLFunction( + const std::string& fname, + const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).second; + } + static MetalShaderLibrary& getBundledLibrary(); + void exec_unary_kernel( + TensorIteratorBase& iter, + const std::string& name, + const std::optional alpha = std::nullopt, + const std::optional scalar_arg_type = std::nullopt); + void exec_binary_kernel( + TensorIteratorBase& iter, + const std::string& name, + const std::optional alpha = std::nullopt, + const std::optional scalar_arg_type = std::nullopt); + + protected: + virtual MTLLibrary_t getLibrary(); + virtual MTLLibrary_t getLibrary( + const std::initializer_list& params); + MTLLibrary_t library = nullptr; + + private: + std::pair getLibraryPipelineState( + MTLLibrary_t lib, + const std::string& fname); + MTLLibrary_t compileLibrary(const std::string& src); + std::string shaderSource; + unsigned nparams; + MTLCompileOptions* compile_options; + std::unordered_map libMap; + std::unordered_map< + std::string, + std::pair> + cplMap; +}; + +class DynamicMetalShaderLibrary : public MetalShaderLibrary { + public: + DynamicMetalShaderLibrary(const std::string& src) : MetalShaderLibrary(src) { + // Compile right away + getLibrary(); + } + ~DynamicMetalShaderLibrary() override; +}; + +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/OperationUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/OperationUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c9fc72bcb4c309de13296f02e9013886868ea849 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/OperationUtils.h @@ -0,0 +1,656 @@ +// Copyright © 2022 Apple Inc. 
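Editor's note: the declarations in this header are typically combined as in the hypothetical Objective-C++ sketch below. The cached-graph lookup, getCurrentMPSStream(), and the tensors `self` and `output` are assumptions (they come from elsewhere in ATen or from the calling kernel) and are not part of this diff:

    // Wrap the cached MPSGraph tensors and the ATen tensors, then run the graph.
    static void run_unary_op(MPSUnaryCachedGraph* cg,
                             const at::Tensor& self,
                             const at::Tensor& output) {
      Placeholder inputPh(cg->inputTensor_, self);
      Placeholder outputPh(cg->outputTensor_, output);
      NSDictionary* feeds   = @{inputPh.getMPSGraphTensor() : inputPh.getMPSGraphTensorData()};
      NSDictionary* results = @{outputPh.getMPSGraphTensor() : outputPh.getMPSGraphTensorData()};
      runMPSGraph(getCurrentMPSStream(), cg->graph(), feeds, results);
    }
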
+ +#pragma once + +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +@interface MPSGraph (PyTorchFixups) +- (MPSGraphTensor*)minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor + secondaryTensor:(MPSGraphTensor*)secondaryTensor + name:(NSString*)name; + +- (MPSGraphTensor*)maximumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor + secondaryTensor:(MPSGraphTensor*)secondaryTensor + name:(NSString*)name; +@end + +using namespace at::mps; + +namespace at::native::mps { + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); + +struct MPSScalar { + id getMTLBuffer() const { + return __builtin_bit_cast(id, buffer.get()); + } + + size_t size = 0; + ScalarType type = ScalarType::Undefined; + c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope) + union { + float f; // MPS doesn't support 'double' + at::Half h; + int64_t i; + bool b; + c10::complex cf; + c10::complex ch; + at::BFloat16 bf16; + } value{}; +}; + +void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results); + +MPSDataType getMPSDataType(ScalarType scalar_type); +static inline MPSDataType getMPSDataType(const TensorBase& t) { + return getMPSDataType(t.scalar_type()); +} +MPSDataType getMPSScalarType(ScalarType scalar_type); +static inline MPSDataType getMPSScalarType(const TensorBase& t) { + return getMPSScalarType(t.scalar_type()); +} +MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type); +std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false); +static inline std::string getMPSTypeString(const TensorBase& t, bool short_name = false) { + return getMPSTypeString(t.scalar_type(), short_name); +} +std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type); +static inline std::string scalarToMetalTypeString(const TensorBase& t) { + return scalarToMetalTypeString(t.scalar_type()); +} +NSArray* getTensorAxes(const TensorBase& t); +NSArray* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim); +std::string getMPSShapeString(MPSShape* shape); +std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false); +std::string getArrayRefString(const IntArrayRef s); +// use has_storage() on the returned tensor to determine if src actually is a view +Tensor gatherViewTensor(const Tensor& src, Tensor& dst); +Tensor& scatterViewTensor(const Tensor& src, Tensor& output); +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64 = false); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64 = false); + +MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray); +MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); +MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil); +// The MPSShape could vary based on memory format +Tensor getTensorView(const Tensor& t, MPSShape* shape); +MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); +MPSShape* 
getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); + +static inline id getMTLBufferStorage(const TensorBase& tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +class Placeholder { + public: + Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray); + Placeholder(MPSGraphTensor* mpsGraphTensor, + const Tensor& self, + MPSShape* mpsShape = nullptr, + bool gatherTensorData = true, + MPSDataType dataType = MPSDataTypeInvalid, + bool useMPSStridedAPI = true); + MPSGraphTensor* getMPSGraphTensor() { + return _placeholder; + } + MPSGraphTensorData* getMPSGraphTensorData() { + return _value; + } + bool isIntermediate() { + return _value == nullptr; + } + + private: + MPSGraphTensor* _placeholder; + MPSGraphTensorData* _value; + Tensor _tensor; +}; + +void resize_tensor(Tensor* output); +Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device); +MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor); +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType); +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType); +MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor); +MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); + +MPSGraph* make_mps_graph(); +void printTensorNDArray(const TensorBase& t); +MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType); + +MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar); + +std::string get_mem_format_string(c10::MemoryFormat memory_format); + +using MPSCacheKey = uint64_t; + +struct MPSCachedKernel { + MPSCachedKernel(NSObject* object) : _object([object retain]) {} + virtual ~MPSCachedKernel() { + [_object release]; + _object = nullptr; + } + + // Delete copy constructor and assignment + MPSCachedKernel(const MPSCachedKernel&) = delete; + void operator=(const MPSCachedKernel&) = delete; + + template + inline T* kernel() const { + return (T*)_object; + } + + private: + NSObject* _object = nullptr; +}; + +// derive this class to cache a graph and its inputs/outputs +// can be used to store any NSObject +struct MPSCachedGraph { + MPSCachedGraph(NSObject* object) : _object([object retain]) {} + virtual ~MPSCachedGraph() { + [_object release]; + _object = nullptr; + } + + template + inline T* as() { + return static_cast(this); + } + + MPSGraph* graph() const { + return (MPSGraph*)_object; + } + NSObject* object() const { + return _object; + } + + private: + NSObject* _object = nullptr; +}; + +struct MPSUnaryCachedGraph : public MPSCachedGraph { + MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; +}; + +struct MPSUnaryGradCachedGraph : public MPSCachedGraph { + 
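// Caches the placeholders for a unary op's backward graph: the incoming
+  // gradient, the forward input (or, per the note below, sometimes the
+  // forward's output) and the computed input gradient.
+ 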
MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; // some backward input is actually the forward's output + MPSGraphTensor* gradInputTensor_ = nil; +}; + +struct MPSBinaryCachedGraph : public MPSCachedGraph { + MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* otherTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; +}; + +struct MPSBinaryGradCachedGraph : public MPSCachedGraph { + MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* otherTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; +}; + +struct MPSKernelCache { + typedef MPSCachedKernel* (^CreateCachedKernelBlock)(); + + struct CacheEntry { + CacheEntry(const std::string& key, MPSCachedKernel* cachedKernel) : cachedKernel_(cachedKernel), key_(key) {} + MPSCachedKernel* cachedKernel_ = nullptr; + std::string key_; + }; + + public: + static MPSKernelCache* getInstance() { + if (_instance_cache == nullptr) { + _instance_cache = new MPSKernelCache(); + } + return _instance_cache; + } + + ~MPSKernelCache() { + dispatch_release(serialQueue_); + for (const auto& i : cache_) { + delete i.second.cachedKernel_; + } + } + + // Disallow the copy constructor and operator= functions + MPSKernelCache(const MPSKernelCache&) = delete; + void operator=(const MPSKernelCache&) = delete; + + MPSCachedKernel* CreateCachedKernel(const std::string& key, CreateCachedKernelBlock createCacheBlock) { + __block MPSCachedKernel* cachedKernel = nil; + MPSCacheKey hash = std::hash{}(key); + dispatch_sync_with_rethrow(serialQueue_, ^() { + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n"); + cachedKernel = entry.cachedKernel_; + } else { + cachedKernel = createCacheBlock(); + CacheEntry entry(key, cachedKernel); + cache_.emplace(hash, entry); + } + }); + return cachedKernel; + } + template + inline T* CreateCachedKernelAs(const std::string& key, CreateCachedKernelBlock createCacheBlock) { + return static_cast(CreateCachedKernel(key, createCacheBlock)); + } + + MPSCachedKernel* LookUp(const std::string& key) const { + __block MPSCachedKernel* cachedKernel = nil; + + MPSCacheKey hash = std::hash{}(key); + dispatch_sync_with_rethrow(serialQueue_, ^() { + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n"); + cachedKernel = entry.cachedKernel_; + } + }); + return cachedKernel; + } + + template + inline T* LookUpAs(const std::string& key) const { + return static_cast(LookUp(key)); + } + + private: + MPSKernelCache() { + serialQueue_ = dispatch_queue_create("kernel cache queue", DISPATCH_QUEUE_SERIAL); + } + + static MPSKernelCache* _instance_cache; + std::unordered_map cache_; + dispatch_queue_t serialQueue_ = nullptr; +}; + +// Common template for creating cached kernel if missing +template +inline T* LookUpOrCreateCachedKernel(const std::string& key, std::function instantiate) { + auto cache_ = MPSKernelCache::getInstance(); + if (auto rc = cache_->LookUpAs(key)) { + return rc; + } + return cache_->CreateCachedKernelAs(key, ^mps::MPSCachedKernel*() { + auto k_ = new mps::MPSCachedKernel(instantiate()); + return k_; + }); 
+} + +// TODO: Improve the overall design of MPSGraphCache. +// https://github.com/pytorch/pytorch/issues/77176 +// Cache holding various keys mapped to graphs +struct MPSGraphCache { + typedef MPSCachedGraph* (^CreateCachedGraphBlock)(); + + struct CacheEntry { + CacheEntry(const std::string& key, MPSCachedGraph* cachedGraph) : cachedGraph_(cachedGraph), key_(key) {} + MPSCachedGraph* cachedGraph_ = nullptr; + std::string key_; + }; + + public: + static MPSGraphCache* getInstance() { + if (_instance_cache == nullptr) { + _instance_cache = new MPSGraphCache(); + } + return _instance_cache; + } + + ~MPSGraphCache() { + dispatch_release(serialQueue_); + + for (const auto& i : cache_) { + delete i.second.cachedGraph_; + } + } + + // Disallow the copy constructor and operator= functions + MPSGraphCache(const MPSGraphCache&) = delete; + void operator=(const MPSGraphCache&) = delete; + + MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + __block MPSCachedGraph* cachedGraph = nil; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync_with_rethrow(serialQueue_, ^() { + // verify the cached entry doesn't already exist + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + } else { + cachedGraph = createCacheBlock(); + CacheEntry entry(key, cachedGraph); + cache_.emplace(hash, entry); + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + return static_cast(CreateCachedGraph(key, createCacheBlock)); + } + + MPSCachedGraph* LookUp(const std::string& key) const { + __block MPSCachedGraph* cachedGraph = nullptr; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync(serialQueue_, ^() { + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* LookUpAs(const std::string& key) const { + return static_cast(LookUp(key)); + } + + private: + MPSGraphCache() { + serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL); + } + // this is defined in OperationUtils.mm to not include + // MPSProfiler.h in header OperationUtils.h + void profileCachedGraph(const CacheEntry& cacheEntry) const; + + static MPSGraphCache* _instance_cache; + std::unordered_map cache_; + dispatch_queue_t serialQueue_ = nullptr; +}; + +// Common template for creating graph with a specified cache if missing +template +inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function instantiate) { + auto cache_ = MPSGraphCache::getInstance(); + if (auto rc = cache_->LookUpAs(key)) { + return rc; + } + return cache_->CreateCachedGraphAs(key, ^mps::MPSCachedGraph*() { + T* newCachedGraph = nil; + @autoreleasepool { + // Initialize graph + auto mpsGraph = mps::make_mps_graph(); + newCachedGraph = new T(mpsGraph); + instantiate(mpsGraph, newCachedGraph); + } + return newCachedGraph; + }); +} + +// Common math operations +MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); + +#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \ + if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \ + TORCH_WARN_ONCE( \ + 
"MPS: no support for int64 for ", \ + op_name, \ + ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \ + } + +/** + * Returns distance from lowest to highest element offset in given tensor. + */ +size_t compute_storage_numel_distance(const TensorBase& t); + +/** + * Checks whether tensor is mapped to a contiguous area in the storage. + */ +inline bool is_dense_in_storage(const TensorBase& t) { + return compute_storage_numel_distance(t) == static_cast(t.numel()); +} + +template , encoder_t> || + std::is_same_v, encoder_t>>> +static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) { + if (C10_UNLIKELY(t.device().type() == kCPU)) { + if constexpr (std::is_same_v, encoder_t>) { + TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op"); + // MPS does not support doubles, silently downcast CPU scalar to float + if (C10_UNLIKELY(t.scalar_type() == kDouble)) { + auto val = static_cast(*reinterpret_cast(t.const_data_ptr())); + [encoder setBytes:&val length:sizeof(val) atIndex:idx]; + return; + } + if (C10_UNLIKELY(t.scalar_type() == kComplexDouble)) { + auto val = static_cast>(*reinterpret_cast*>(t.const_data_ptr())); + [encoder setBytes:&val length:sizeof(val) atIndex:idx]; + return; + } + [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx]; + } else { + TORCH_CHECK(false, "Passed CPU tensor to MPS op"); + } + return; + } + [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx]; +} + +// Implementation of setBytes for containers vs trivially copiable types must be separate +// Containers like `std::array` could have been uploaded directly, but `c10::ArrayRef`, +// while trivially copiable, includes padding which if copied as Metal shader parameters +// might overwrite other values +template < + typename T, + typename = std::enable_if_t || std::is_same_v || + (std::is_class_v && std::is_trivially_copyable_v && !detail::has_size_type_v)>> +static inline void mtl_setBytes(id encoder, const T val, unsigned idx) { + [encoder setBytes:&val length:sizeof(T) atIndex:idx]; +} + +template >> +static inline void mtl_setBytes(id encoder, const Container& values, unsigned idx) { + [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx]; +} + +static inline void mtl_setBytes(id encoder, const MPSScalar& s, unsigned idx) { + [encoder setBytes:&s.value length:s.size atIndex:idx]; +} + +static size_t iter_tensor_offset(TensorIteratorBase& iter, unsigned idx) { + // At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id object + // But TensorIterator constructs data_ptr as if base was just a raw pointer + // Workaround this problem by computing an offset from the start of the tensor, which works for both + // tensor views and sliced 64-bit iterators + return reinterpret_cast(iter.data_ptr(idx)) - + reinterpret_cast(iter.tensor_base(idx).storage().data()); +} + +static inline void bind_iter_tensors(id encoder, + TensorIteratorBase& iter, + std::optional ntensors = std::nullopt) { + for (auto idx : c10::irange(ntensors.value_or(iter.ntensors()))) { + auto& t = iter.tensor_base(idx); + // Handle CPU scalars + if (C10_UNLIKELY(t.device().type() == kCPU)) { + mtl_setBuffer(encoder, t, idx); + continue; + } + auto offs = iter_tensor_offset(iter, idx); + [encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx]; + } +} + +namespace detail { +template +inline void 
mtl_setArg(id encoder, const T& val, unsigned idx) { + mtl_setBytes(encoder, val, idx); +} + +inline void mtl_setArg(id encoder, id val, unsigned idx) { + [encoder setBuffer:val offset:0 atIndex:idx]; +} + +template <> +inline void mtl_setArg(id encoder, const Tensor& val, unsigned idx) { + mtl_setBuffer(encoder, val, idx); +} + +template <> +inline void mtl_setArg(id encoder, const std::optional& val, unsigned idx) { + if (val.has_value()) { + mtl_setBuffer(encoder, val.value(), idx); + } +} + +template <> +inline void mtl_setArg(id encoder, const TensorBase& val, unsigned idx) { + mtl_setBuffer(encoder, val, idx); +} +// MPS does not support doubles, so cast it down to float before passing as an argument +template <> +inline void mtl_setArg(id encoder, const double& val, unsigned idx) { + float val_f = static_cast(val); + mtl_setBytes(encoder, val_f, idx); +} +} // namespace detail + +template +static inline void mtl_setArgs(id encoder, const T& val) { + detail::mtl_setArg(encoder, val, idx); +} + +template +static inline void mtl_setArgs(id encoder, const T& val, Args&&... args) { + detail::mtl_setArg(encoder, val, idx); + mtl_setArgs(encoder, std::forward(args)...); +} + +static inline void mtl_dispatch1DJob(id encoder, + id cplState, + NSUInteger length) { + static_assert(sizeof(NSUInteger) == sizeof(uint64_t)); + const auto maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup]; + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1); + [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; +} + +id generateKernelDataOffsets(id commandEncoder, + const TensorIteratorBase& iter, + bool use_64bit_index = false); + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) { + return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()}; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) { + return @{ + p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) { + return @{ + p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) { + return @{ + p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(), + p4.getMPSGraphTensor() : p4.getMPSGraphTensorData(), + }; +} + +inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) { + runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result)); +} + +inline bool supportsComplex() { + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); +} + +// MPS yet to support double types, but starting from MacOS 14, supports bfloat16 +inline bool supportedFloatingType(ScalarType dtype) { + return dtype == kFloat || dtype == kHalf || dtype == kBFloat16; +} + +inline bool supportedFloatingType(const TensorBase& t) { + return supportedFloatingType(t.scalar_type()); +} + +inline bool supportedFloatingOrComplexType(ScalarType dtype) { + if (dtype == kComplexFloat || dtype == kComplexHalf) { + return supportsComplex(); + } + return supportedFloatingType(dtype); +} +inline bool 
supportedFloatingOrComplexType(const TensorBase& t) { + return supportedFloatingOrComplexType(t.scalar_type()); +} + +inline void checkSupportsBFloat16() { + TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), + "MPS bfloat16 type is supported on MacOS 14.0 or newer."); +} + +inline bool needsGather(const TensorBase& t) { + static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); + return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()); +} + +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/TensorFactory.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/TensorFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..08fbea015db7f4a4c9dc91327b15e17fc3e30eed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/TensorFactory.h @@ -0,0 +1,14 @@ +// Copyright © 2022 Apple Inc. + +#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) AT_DISPATCH_CASE( \ + at::ScalarType::Half, \ + __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h new file mode 100644 index 0000000000000000000000000000000000000000..c137787ce0ed7fcb595686e87c093756144801ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h @@ -0,0 +1,20 @@ +#pragma once + +#ifndef __METAL__ +#include +using ulong = unsigned long; +#define _ARRAY_NS std +#else +#include +#define _ARRAY_NS metal +#endif + +template +struct UpsampleParams { + _ARRAY_NS::array input_strides; + _ARRAY_NS::array input_sizes; + _ARRAY_NS::array output_strides; + _ARRAY_NS::array output_sizes; + _ARRAY_NS::array scales; + bool align_corners; +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d161dc10cf4dd036a4c1656e61e7d62f3f76b7d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h @@ -0,0 +1,10 @@ +#pragma once + +namespace at::native::mps { +void binary_op_kernel( + const std::string func_name, + const Tensor& input, + const Tensor& other, + const Tensor& output, + const std::optional alpha = std::nullopt); +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..89d635205c2b1f9333ae1e40b6f7423ea3b508d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adam_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + 
TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8eddbcd9125bc72920132627781a62f9ddfa42 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h @@ -0,0 +1,35 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adam_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..90f9d97b91c7909b9c9dfcb9178b05ada52603ec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h @@ -0,0 +1,37 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adamw_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h new file mode 100644 index 
0000000000000000000000000000000000000000..6071f05faba68cd6ab2f47d100d7f5aa3bd7b2f8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adamw_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/Indexing.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/Indexing.h new file mode 100644 index 0000000000000000000000000000000000000000..bbbbb52ca52d80db94dfb708b7afe3415e5c1785 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/Indexing.h @@ -0,0 +1,8 @@ +// Copyright © 2022 Apple Inc. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +using namespace at::mps; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h new file mode 100644 index 0000000000000000000000000000000000000000..5b0cfb6c8d86add56f0369b62664defcc8f2b45c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h @@ -0,0 +1,362 @@ +#pragma once +#include +#include +#include + +static_assert(sizeof(bool) == 1); + +namespace at::native::mps { + +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kmaxThreadGroups = 32; +static constexpr int64_t kmaxTensors = 32; + +struct MetadataArguments { // the size of this struct must be less than 4 kilobytes + uint64_t numels[kmaxTensors]; + uint64_t threadgroup_to_tensor[kmaxThreadGroups]; + uint64_t threadgroup_to_chunk[kmaxThreadGroups]; +}; + +struct FusedAdamEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize) const { + mtl_setArgs( + computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize) const { + mtl_setArgs( + computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); + } +}; + +template +struct FusedSgdEncodingFunctor {}; + +template <> +struct FusedSgdEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double momentum, + const double lr, + const 
double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step) const { + mtl_setArgs(computeEncoder, + tensorArgumentBuffer, + metadata_arguments, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double momentum, + const at::Tensor& lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step) const { + mtl_setArgs(computeEncoder, + tensorArgumentBuffer, + metadata_arguments, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step); + } +}; + +template <> +struct FusedSgdEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double lr, + const bool maximize) const { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const at::Tensor& lr, + const bool maximize) const { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); + } +}; + +std::pair, id> getFusedAdamCPLState(const std::string& fname); +template +static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name, + std::vector>& tensor_lists, + at::TensorList state_steps, + encoder_func_t encode, + ArgTypes... args) { + const auto num_tensors = tensor_lists[0].size(); + + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth"); + for (const auto& d : c10::irange(depth)) { + const auto scalar_type = tensor_lists[d][0].scalar_type(); + TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16, + "Only float, bfloat and half are supported"); + } + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + // Remove comment for debugging + /* + mpsStream->addCompletedHandler(^(id cb) { + [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) { + NSLog(@"MPSStream: %@", log); + } + ]; + }); + */ + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [fusedOptimizerPSO, fusedOptimizerFunc] = getFusedAdamCPLState(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]}); + + [computeEncoder setComputePipelineState:fusedOptimizerPSO]; + + // BufferIndex is the index in the kernel function + auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease]; + id tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int64_t tensor_loc = 0; + int64_t threadgroup_loc = 0; + MetadataArguments metadata_arguments; + + for (const auto tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + + for (const auto& d : c10::irange(depth)) { 
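+        // Bind this tensor's buffer for every tensor list (e.g. params / grads /
+        // exp_avgs for fused Adam) into the argument buffer. Slots are laid out
+        // as one block of kmaxTensors entries per depth, hence the index
+        // d * kmaxTensors + tensor_loc.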
+ mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) + usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + if (!state_steps.empty()) { + mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + } + metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel(); + + tensor_loc++; + + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + + for (const auto& chunk : c10::irange(chunks)) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + + threadgroup_loc++; + + const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1; + // Reach the maximum threadgroups per dispatch + const auto blocks_full = threadgroup_loc == kmaxThreadGroups; + + if (tensor_full || blocks_full) { + encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // Reset + threadgroup_loc = 0; + if (chunk == chunks - 1) { + // last chunk + tensor_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + } else { + // reuse the current tensor since the current one isn't done. + metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1]; + + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + for (const auto& d : c10::irange(depth)) { + mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) + usage:MTLResourceUsageWrite | MTLResourceUsageRead]; + } + if (!state_steps.empty()) { + mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors); + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + } + tensor_loc = 1; + } + } + } + } + + if (threadgroup_loc != 0) { + encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + + getMPSProfiler().endProfileKernel(fusedOptimizerPSO); + } + }); +} + +std::pair, id> getAmpCPLState(const std::string& fname); +template +void multi_tensor_apply(const std::string& kernel_name, + std::vector>& tensor_lists, + ArgTypes... 
args) { + const auto num_tensors = tensor_lists[0].size(); + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists must match depth."); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [pipeline, function] = getAmpCPLState(kernel_name); + [computeEncoder setComputePipelineState:pipeline]; + + id argumentEncoder = [function newArgumentEncoderWithBufferIndex:0]; + auto tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int tensor_loc = 0; + int threadgroup_loc = 0; + MetadataArguments metadata_arguments; + std::memset(&metadata_arguments, 0, sizeof(metadata_arguments)); + + for (size_t t = 0; t < num_tensors; t++) { + if (tensor_lists[0][t].numel() == 0) + continue; + + // bind each tensor in this list to the correct slots across depths + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + + // save number of elements for this tensor + metadata_arguments.numels[tensor_loc] = tensor_lists[0][t].numel(); + int currentTensorIndex = tensor_loc; + tensor_loc++; + + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + ((numel % kChunkSize) ? 1 : 0); + + // process tensor in chunks based on max chunk size + for (uint chunk = 0; chunk < chunks; chunk++) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = currentTensorIndex; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + threadgroup_loc++; + + // dispatch when we've filled the threadgroup array or finished the chunks + const bool dispatch_now = (threadgroup_loc == kmaxThreadGroups) || (chunk == chunks - 1); + if (dispatch_now) { + // check for a partial dispatch (i.e. 
more chunks remain for the current tensor) + bool partial = (chunk != chunks - 1); + uint carried_numels = 0; + if (partial) { + carried_numels = metadata_arguments.numels[currentTensorIndex]; + } + + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, (uint32_t)64), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // prepare for the next batch: reset threadgroup count and create a new buffer + threadgroup_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + if (partial) { + // for a partial dispatch, rebind the partially processed tensor to slot 0 + // so that its metadata is in the correct location + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + 0); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + metadata_arguments.numels[0] = carried_numels; + // the currently processed tensor now lives at index 0 + currentTensorIndex = 0; + tensor_loc = 1; + } else { + tensor_loc = 0; + } + } + } + } + + if (threadgroup_loc != 0) { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, static_cast(64)), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + } + }); +} + +} // namespace at::native::mps diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/mtia/EmptyTensor.h b/phivenv/Lib/site-packages/torch/include/ATen/native/mtia/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..82ad76164512c33cec093022853682f43ae2177b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/mtia/EmptyTensor.h @@ -0,0 +1,42 @@ + +#pragma once +#include + +namespace at::detail { + +TensorBase empty_mtia( + IntArrayRef size, + ScalarType dtype, + std::optional device_opt, + std::optional memory_format_opt); + +TensorBase empty_mtia( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TensorBase empty_mtia(IntArrayRef size, const TensorOptions& options); + +TensorBase empty_strided_mtia( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + std::optional device_opt); + +TensorBase empty_strided_mtia( + IntArrayRef size, + IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TensorBase empty_strided_mtia( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions& options); + +} // namespace at::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..5f80526a87f3003db1c674e2d31d8ab670954c4d --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace at::native { + +enum class NESTED_DENSE_OP : uint8_t { ADD, MUL }; + +using nested_dense_elementwise_fn = void (*)( + Tensor& result, + const Tensor& self, + const Tensor& other, + const NESTED_DENSE_OP& op); + +DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h new file mode 100644 index 0000000000000000000000000000000000000000..ac9a35b6992dbdd200a5a23576556c6ecded8e33 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_API Tensor NestedTensor_to_padded_tensor_generic( + const Tensor& t, + double padding, + OptionalIntArrayRef output_size); + +template +Tensor map_nt(const Tensor& nt, Func f) { + auto* nt_impl = get_nested_tensor_impl(nt); + const auto& sizes = nt_impl->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl->get_buffer()), sizes); +} +template +Tensor map_nt_binary(const Tensor& nt_1, const Tensor& nt_2, Func f){ + auto* nt_impl_1 = get_nested_tensor_impl(nt_1); + auto* nt_impl_2 = get_nested_tensor_impl(nt_2); + const auto& sizes = nt_impl_1->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl_1->get_buffer(), nt_impl_2->get_buffer()), sizes); +} + +C10_ALWAYS_INLINE std::pair _check_nested_layer_norm_inputs( + const NestedTensorImpl& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const size_t normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + // Check that the normalized_shape has the exact same sizes as the last dimensions from the NestedTensor input + // Also, compute M and N considering the idiosyncracies of NestedTensors + int64_t N = 1; + for (const auto i: c10::irange(normalized_ndim)) { + TORCH_CHECK( + input.opt_size(-normalized_ndim + i).has_value(), + "normalized_shape extends into irregular dimensions for the nested tensor" + ); + TORCH_CHECK( + normalized_shape[i] == input.opt_size(-normalized_ndim + i), + "The shape at dimension ", + i, + "of normalized_shape doesn't match the input" + ); + N *= normalized_shape[i]; + } + + const int64_t M = input.numel() / N; + + return std::make_pair(M, N); +} + +Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h new file 
mode 100644 index 0000000000000000000000000000000000000000..fb9b1f4e4c497e32a5607ca31bd78cf61606bfd2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h @@ -0,0 +1,103 @@ +/** + * Transformer-specific NestedTensor utility functions. + * + * Not co-located with NestedTensor core code yet because they only + * support specific cases needed in transformers. + */ +#pragma once + +#include + +#include +#include + +namespace c10 { +class Scalar; +} // namespace c10 + +namespace at { +class Tensor; +namespace native { +struct NestedTensorImpl; + +// Requires that self is a contiguous NestedTensor, other is not a +// NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self +// must have a consistent last dimension across its included Tensors +// and that dimension must match other.size(0). +Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other); + +// Requires that mat1 is a contiguous NestedTensor, self & mat2 are +// not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1 +// has a consistent last dimension across its included Tensors that +// matches mat2.size(0). +Tensor NestedTensor_times_Tensor_plus_Tensor_addmm( + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const c10::Scalar& beta, + const c10::Scalar& alpha, + std::optional use_gelu = std::nullopt); + +Tensor NestedTensor_add_NestedTensor_in_place( + const Tensor& self, + const Tensor& other); + +TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor( + const Tensor& sizes, + int64_t extra_elements); + +Tensor NestedTensor_from_padded_tensor_cpu( + const Tensor& padded, + const NestedTensorImpl& nt); + +TORCH_API Tensor NestedTensor_to_mask(const Tensor& nt, std::optional mask_dim, std::optional mask_dim_length); + +template +void remove_padding_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +template +void remove_padding_transform0213_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +template +void add_padding_kernelLauncher( + T* input, + T* output, + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + const std::vector& output_sizes, + const int batch_size, + const int output_batch_size); + +TORCH_API Tensor flash_attention_helper( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); + +TORCH_API std::tuple mem_efficient_helper_nested_unpacked( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..af3f363f93ab1391c6b0d65a1a6dd3f64e6beca3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h @@ -0,0 +1,40 @@ +#pragma once +#include + +namespace at::native::preprocessing { + +/** + * This function will take nested query, key, and value + * and will preprocess it in order to run with either + * the flash-attention or 
efficient-attention kernels. + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing( + const Tensor& query, + const Tensor& key, + const Tensor& value); + +/** + * This function will take nested query, key, and value, grad_out, and out + * and will preprocess it in order to run with either + * the flash-attention or efficient-attention kernels backwards. + * We use both functions to avoid having to do the same preprocessing + * for cumulative_sequence_length_q and cumulative_sequence_length_kv + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const Tensor& cumulative_sequence_length_q, + const Tensor& cumulative_sequence_length_kv, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_kv); + +} // namespace at::native::preprocessing diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..644897d017a1c3b34ac7ec8ecf9129770f6fe7ca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h @@ -0,0 +1,449 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS + +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +#include +#include +#include + +namespace at::native { +struct NestedTensorImpl; + +// The following functions are used to construct nested tensors from buffers and +// metadata. + +inline at::Tensor wrap_buffer(const at::Tensor& buffer, const at::Tensor& nested_sizes) { + TORCH_CHECK( + buffer.dim() == 1, + "Expected given buffer to be 1dim, but got ", + buffer.dim(), + " instead."); + TORCH_CHECK( + buffer.is_contiguous(), "Expected given buffer to be contiguous."); + return at::detail::make_tensor( + buffer, nested_sizes); +} + +// TODO: Figure out if we need a non-moving wrap_buffer() +inline at::Tensor wrap_buffer( + const at::Tensor& buffer, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + buffer.is_contiguous(), "Given buffer must be contiguous."); + return at::detail::make_tensor( + buffer, + std::move(nested_sizes), + std::move(nested_strides), + std::move(storage_offsets)); +} + +inline at::Tensor get_buffer(const at::Tensor& tensor) { + return get_nested_tensor_impl(tensor)->get_buffer(); +} + +/** + * Create a new nested tensor that is a view of a base nested tensor + * + * create_view_tensor calls a specialized constructor that copies the + * keys from base onto the new view tensor being created. + * The storage is shared between the base and the returned view tensor + * + * All callers of this helper must: + * - Only return a view of the input + * - Must be explicit and define a derivative + * + * @param base Base tensor to construct view from. + * @param nested_sizes View tensors' sizes. + * @param nested_strides View tensors' strides. + * @param storage_offsets View tensors' offsets. 
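+ *
+ * Typical callers are nested view ops that reuse the base buffer while
+ * substituting re-derived size/stride/offset metadata (for example, a nested
+ * transpose can swap the corresponding columns of the size and stride
+ * matrices while leaving the buffer untouched).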
+ * @return A newly constructed view tensor + */ +inline at::Tensor create_nested_view_tensor( + const at::Tensor& base, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT( + base.is_nested(), + "This function can only be used to create nested tensor views"); + TORCH_INTERNAL_ASSERT( + c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::AutogradFunctionality), + "Creating a non differentiable nested tensor view in a CompositeImplicit function is not allowed."); + return at::detail::make_tensor( + c10::TensorImpl::VIEW, + base, + std::move(nested_sizes), + std::move(nested_strides), + std::move(storage_offsets)); +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Helper functions for getting information about a nested tensor's shape. + +int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt); + +// The sizes of the underlying tensors +inline std::vector NestedTensor_get_sizes( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector sizes(ntensors); + if (ntensors == 0) { + return sizes; + } + const Tensor& sizemat = self_ptr->get_nested_sizes(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars has empty sizes + if (orig_dim == 0) { + return sizes; + } + const int64_t* sizemat_ptr = sizemat.const_data_ptr(); + + for (const auto i : c10::irange(ntensors)) { + sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim); + sizemat_ptr += orig_dim; + } + return sizes; +} + +TORCH_API std::vector NestedTensor_get_max_size( + const NestedTensorImpl& nt); + +std::vector NestedTensor_get_max_size_from_size_tensor( + const Tensor& sizes); + +inline std::vector NestedTensor_get_sizes(const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_sizes(self_ptr); +} +// The strides of the underlying tensors +inline std::vector NestedTensor_get_strides( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector strides(ntensors); + if (ntensors == 0) { + return strides; + } + const Tensor& stridemat = self_ptr->get_nested_strides(); + int64_t orig_dim = stridemat.size(1); + // nesting scalars has empty strides + if (orig_dim == 0) { + return strides; + } + const int64_t* stridemat_ptr = stridemat.const_data_ptr(); + for (const auto i : c10::irange(ntensors)) { + strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim); + stridemat_ptr += orig_dim; + } + return strides; +} + +inline std::vector NestedTensor_get_strides( + const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_strides(self_ptr); +} + +inline void check_numel_equals_buffer_size(const at::Tensor& self) { + auto self_impl = get_nested_tensor_impl(self); + TORCH_CHECK( + self.numel() == static_cast(self_impl->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} + +inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) { + TORCH_CHECK( + self_ptr->numel() == static_cast(self_ptr->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} + +// Helper function to get size / stride / offset for a nested/normal tensor. 
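+// For a nested tensor, the i-th constituent's sizes/strides come from row i of
+// the nested size/stride matrices and its offset from the stored offsets array;
+// for a regular batched tensor they fall back to sizes().slice(1),
+// strides().slice(1) and storage_offset() + strides()[0] * i.
+// Illustrative example (hypothetical dense input): for a contiguous tensor of
+// shape [4, 3, 5], get_size_for_index(t, 2) yields [3, 5] and
+// get_offset_for_index(t, 2) yields t.storage_offset() + 2 * 15.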
+inline IntArrayRef get_size_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + std::vector tensor_sizes = + NestedTensor_get_sizes(get_nested_tensor_impl(tensor)); + return tensor_sizes[i]; + } else { + return tensor.sizes().slice(1); + } +} + +inline IntArrayRef get_stride_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + std::vector tensor_strides = + NestedTensor_get_strides(get_nested_tensor_impl(tensor)); + return tensor_strides[i]; + } else { + return tensor.strides().slice(1); + } +} + +inline int64_t get_offset_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + int64_t* offsets_ptr = get_nested_tensor_impl(tensor) + ->get_storage_offsets() + .data_ptr(); + return offsets_ptr[i]; + + } else { + int64_t offset = tensor.storage_offset(); + return offset + tensor.strides()[0] * i; + } +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Data structures and functions for generically applying a function on a nested +// tensor. +namespace impl { + +template +struct NestedNode { + NestedNode() = delete; + explicit NestedNode(std::vector children) + : _is_leaf(false), _children(std::move(children)) {} + explicit NestedNode(TensorList children) + : _is_leaf(false), _children(children.vec()) {} + explicit NestedNode(T payload) + : _is_leaf(true), _payload(std::move(payload)) {} + NestedNode(const NestedNode&) = delete; + NestedNode& operator=(const NestedNode&) = delete; + NestedNode(NestedNode&&) noexcept = default; + NestedNode& operator=(NestedNode&&) noexcept = default; + ~NestedNode() = default; + inline bool is_leaf() const { + return _is_leaf; + } + inline size_t degree() const { + return _children.size(); + } + inline const std::vector unbind() const { + return _children; + } + inline T children(size_t i) const { + return _children[i]; + } + inline const T& payload() const { + return _payload; + } + inline T& payload() { + return _payload; + } + + private: + bool _is_leaf; + std::vector _children; + T _payload{}; +}; + +using TensorNode = NestedNode; + +template +class _map; + +template +class _map> { + public: + static A function_one(const F& fn, const Args&... nested_node) { + return fn(nested_node...); + } + static NestedNode function( + const F& fn, + const NestedNode&... nested_node) { + size_t degree = 0; + bool all_leaf = true; + c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&all_leaf, °ree](auto n) { + all_leaf = all_leaf && (n.is_leaf()); + if (degree > 1 && n.degree() > 1) { + TORCH_CHECK( + degree == n.degree(), "NestedNodes must match in degree."); + } + if (n.degree() > degree) { + degree = n.degree(); + } + return nullptr; + }); + // All NestedNodes just wrap regular objects. + if (all_leaf) { + return NestedNode(std::forward(fn)(nested_node.payload()...)); + } + // Some NestedNodes wrap regular Tensors, some NestedTensors and some other + // types. + std::vector result; + for (size_t i = 0; i < degree; i++) { + auto children = c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&i](auto a) { + static_assert( + c10::guts::is_instantiation_of::value, + "Internal error."); + // Broadcast regular arguments across NestedTensor constituents. + // This could be a Tensor, integer or anything else really. + if (a.is_leaf()) { + return a.payload(); + } + // Broadcast NestedTensors with one constituent. 
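+          // i.e. a nested argument that holds a single child is reused for
+          // every output constituent, mirroring how leaf payloads are
+          // broadcast above.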
+ if (a.degree() == 1 && !a.is_leaf()) { + return a.children(0); + } + TORCH_CHECK(a.degree() > 0, "Internal assert."); + return a.children(i); + }); + std::apply( + [&result, &fn](Args... filtered) { + result.emplace_back(function_one(fn, filtered...)); + }, + std::move(children)); + } + return NestedNode(std::move(result)); + } +}; + +// TODO: Add static assert to verify lambda arguments match nested_node types +template +static inline NestedNode< + typename c10::guts::infer_function_traits::type::return_type> +map(F&& fn, const NestedNode&... nested_node) { + return _map< + F, + typename c10::guts::infer_function_traits::type::return_type, + typename c10::guts::infer_function_traits::type::parameter_types>:: + function(std::forward(fn), nested_node...); +} + +inline TensorNode get_nested_tensor_structure(at::Tensor tensor) { + if (get_nested_tensor_impl_or_null(tensor) == nullptr) { + return TensorNode(std::move(tensor)); + } + return TensorNode(tensor.unbind()); +} + +inline Tensor wrap_tensor_node( + TensorNode tensor_node, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory) { + TORCH_CHECK( + !tensor_node.is_leaf(), "Expected TensorNode to wrap a list of Tensors."); + TensorOptions options_ = + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( + pin_memory); + if (tensor_node.degree() == 0) { + return wrap_buffer(ones({0}, dtype, layout, device), ones({})); + } + + // Fast path: if all tensors are on CPU, have contiguous memory, and the same + // dtype, copying can be done much faster. + bool all_tensors_cpu = true; + bool all_tensors_contiguous = true; + bool all_tensors_same_dtype = true; + auto first_dtype = tensor_node.children(0).dtype(); + std::vector start_offsets(tensor_node.degree()); + start_offsets[0] = 0; + long total_size = 0; + for (const auto i : c10::irange(tensor_node.degree())) { + all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu(); + all_tensors_contiguous = + all_tensors_contiguous && tensor_node.children(i).is_contiguous(); + all_tensors_same_dtype = all_tensors_same_dtype && + (first_dtype == tensor_node.children(i).dtype()); + if (!(all_tensors_cpu && all_tensors_contiguous && + all_tensors_same_dtype)) { + break; + } + if (i > 0) { + start_offsets[i] = + start_offsets[i - 1] + tensor_node.children(i - 1).numel(); + } + total_size += tensor_node.children(i).numel(); + } + + TensorOptions options; + Tensor nt_buffer, nt_sizes; + if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) { + nt_buffer = at::empty({total_size}, tensor_node.children(0).options()); + nt_sizes = at::empty( + {static_cast(tensor_node.degree()), + static_cast(tensor_node.children(0).sizes().size())}, + TensorOptions().dtype(kLong)); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::BFloat16, + c10::typeMetaToScalarType(first_dtype), + "create_nt_buffer", + [&]() { + at::parallel_for( + 0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + // Only try copying memory if there is more than 0 elements + // for a certain tensor + if (tensor_node.children(i).numel() > 0) { + memcpy( + nt_buffer.mutable_data_ptr() + start_offsets[i], + tensor_node.children(i).const_data_ptr(), + tensor_node.children(i).numel() * sizeof(scalar_t)); + } + } + }); + }); + long sizes_offset = 0; + for (size_t i = 0; i < tensor_node.degree(); ++i) { + auto tensor_sizes = tensor_node.children(i).sizes(); 
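// Append this constituent's shape as the next row of the (degree x ndim)
// int64 size matrix backing nt_sizes.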
+ for (int64_t tensor_size : tensor_sizes) { + nt_sizes.mutable_data_ptr()[sizes_offset++] = tensor_size; + } + } + options = nt_buffer.options().merge_in(options_); + } else { // Slow path + std::vector flat_tensors; + std::vector sizes; + for (const auto i : c10::irange(tensor_node.degree())) { + flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); + sizes.push_back( + tensor(c10::IntArrayRef(tensor_node.children(i).sizes()))); + } + options = flat_tensors[0].options().merge_in(options_); + nt_buffer = at::cat(flat_tensors); + nt_sizes = at::native::stack(sizes); + } + + return wrap_buffer(nt_buffer.to(options), nt_sizes); +} + +} // namespace impl + +// This function is meant to ease rapid operator coverage for +// NestedTensor kernels. It is not meant to be efficient. Use it judiciously. +template +inline at::Tensor map_nested_tensor(F&& fn, A... a) { + return wrap_tensor_node( + impl::map(std::forward(fn), impl::get_nested_tensor_structure(a)...), + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h new file mode 100644 index 0000000000000000000000000000000000000000..f478ba3f29526cc629316935aa8ed32f0b25b4bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +TORCH_API Tensor& quantize_tensor_per_tensor_affine( + const Tensor& rtensor, + Tensor& qtensor, + double scale, + int64_t zero_point); +TORCH_API Tensor& quantize_tensor_per_channel_affine( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + Tensor zero_points, + int64_t axis); + +TORCH_API Tensor& quantize_tensor_per_channel_float_qparams( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +TORCH_API Tensor& dequantize_tensor_per_tensor_affine( + const Tensor& qtensor, + Tensor& rtensor, + double scale, + int64_t zero_point); +TORCH_API Tensor& dequantize_tensor_per_channel_affine( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + Tensor zero_points, + int64_t axis); +TORCH_API Tensor& dequantize_tensor_per_channel_float_qparams( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using quantize_tensor_per_tensor_affine_fn = + void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point); + +using quantize_tensor_per_channel_affine_fn = void (*)( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using quantize_tensor_per_channel_float_qparams_fn = void (*)( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using dequantize_tensor_per_tensor_affine_fn = + void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point); + +using dequantize_tensor_per_channel_affine_fn = void (*)( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using dequantize_tensor_per_channel_float_qparams_fn = void (*)( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using 
quantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point); + +using dequantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point); + +DECLARE_DISPATCH( + quantize_tensor_per_tensor_affine_fn, + quantize_tensor_per_tensor_affine_stub) +DECLARE_DISPATCH( + quantize_tensor_per_channel_affine_fn, + quantize_tensor_per_channel_affine_stub) +DECLARE_DISPATCH( + quantize_tensor_per_channel_float_qparams_fn, + quantize_tensor_per_channel_float_qparams_stub) + +DECLARE_DISPATCH( + dequantize_tensor_per_tensor_affine_fn, + dequantize_tensor_per_tensor_affine_stub) +DECLARE_DISPATCH( + dequantize_tensor_per_channel_affine_fn, + dequantize_tensor_per_channel_affine_stub) +DECLARE_DISPATCH( + dequantize_tensor_per_channel_float_qparams_fn, + dequantize_tensor_per_channel_float_qparams_stub) + +DECLARE_DISPATCH( + quantize_tensor_per_tensor_affine_sub_byte_fn, + quantize_tensor_per_tensor_affine_sub_byte_stub) + +DECLARE_DISPATCH( + dequantize_tensor_per_tensor_affine_sub_byte_fn, + dequantize_tensor_per_tensor_affine_sub_byte_stub) + +template +TORCH_API Tensor quantize_tensor( + Tensor rtensor, + Tensor qtensor, + double scale, + int64_t zero_point); +template +TORCH_API Tensor dequantize_tensor( + Tensor qtensor, + Tensor rtensor, + double scale, + int64_t zero_point); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h new file mode 100644 index 0000000000000000000000000000000000000000..e26345368029592c66b5517a152bc071d37fb571 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include + +namespace at::native { + +// Quantize a float value into a uint value given scale and zero_point +template +TORCH_API T quantize_val(double scale, int64_t zero_point, float value); +// TODO combine this with quantize_val once the numerics for ARM are aligned +// with it +template +T quantize_val_arm( + const float scale, + const int32_t zero_point, + const float value); +template +void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + T* dst, + size_t count = 8); +template +TORCH_API float dequantize_val(double scale, int64_t zero_point, T value); +template +TORCH_API float dequantize_vec( + double scale, + int64_t zero_point, + const T* src, + float* dst, + size_t count = 8); +template +TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src); + +// Given a multiplier and a zero_point, requantize int32_t computed values back +// to quantized values. 
See comment above +// make_per_tensor_affine_quantizer function for the usage of int64_t +template +TORCH_API DST_T +requantize_from_int(double multiplier, int64_t zero_point, int64_t src); + +int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/ConvUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/ConvUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..679777f8d65862e79fb8f3deb6aaa3b2d5c9428e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/ConvUtils.h @@ -0,0 +1,62 @@ +#pragma once +#include +#include + +namespace at::native::quantized { +namespace { +// MakeConvOutputShape used from both CPU and CUDA libraries +// and exporting symbol from torch_cpu would probably take more storage +// than duplicating implementation which likely be inlined away +template +at::SmallVector MakeConvOutputShape( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const torch::List& stride, + const torch::List& padding, + const torch::List& dilation); + +#if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK) +template <> +at::SmallVector MakeConvOutputShape<2>( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const at::List& stride, + const at::List& padding, + const at::List& dilation) { + const int H = input_image_shape[0]; + const int W = input_image_shape[1]; + const int64_t Y_H = + (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; + const int64_t Y_W = + (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; + return {N, M, Y_H, Y_W}; +} + +template <> +at::SmallVector MakeConvOutputShape<3>( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const at::List& stride, + const at::List& padding, + const torch::List& dilation) { + const int D = input_image_shape[0]; + const int H = input_image_shape[1]; + const int W = input_image_shape[2]; + const int64_t Y_D = + (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; + const int64_t Y_H = + (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; + const int64_t Y_W = + (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1; + return {N, M, Y_D, Y_H, Y_W}; +} + +#endif +} // anonymous namespace +} // namespace at::native::quantized diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/Copy.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..6b21a8c8a597e3964bdda3eaa255bced5ace4064 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/Copy.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace at::native { + +Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src); +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h new file mode 100644 index 0000000000000000000000000000000000000000..143804f990f5cf8c64eb33d5ca03a25e1d48f6e1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h @@ -0,0 +1,67 @@ +#pragma 
once + +#include +#include +#include + +namespace at { + +struct TensorIterator; + +namespace native { + +using fake_quant_tensor_cachemask_fn = void (*)( + Tensor& output, + Tensor& mask, + const Tensor& input, + float sc, + int64_t z_point, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)( + Tensor& output, + Tensor& mask, + const Tensor& input, + const Tensor& sc, + const Tensor& z_point, + const Tensor& fake_quant_enabled, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_learnable_grad_tensor_fn = void (*)( + TensorIterator& iter, + float scale, + float inv_scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + float grad_factor); + +DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub) +DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub) +DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub) + +using fake_quant_per_channel_fn = void (*)( + TensorIterator &iter, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_per_channel_cachemask_fn = void (*)( + TensorIterator &iter, + TensorIterator &iter_mask, + int64_t quant_min, + int64_t quant_max); + +DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub) + +using fake_quant_learnable_per_channel_fn = void (*)( + TensorIterator &iter, + int64_t quant_min, + int64_t quant_max, + float grad_factor); + +DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub) + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/IndexKernel.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9364f949ca02c1294cc08704521d9bec48c664c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/IndexKernel.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace at::native { +using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point); +using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point); + +DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub) +DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub) + + +} // at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/PackedParams.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/PackedParams.h new file mode 100644 index 0000000000000000000000000000000000000000..03c812603a95b1c97ef84101882671017619f245 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/PackedParams.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include + +struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) = 0; + + // out variant of LinearPackedParamsBase::apply + virtual at::Tensor& apply_out( + const at::Tensor& /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/, + at::Tensor& output) { + throw std::runtime_error( + 
"apply_out is not implemented for this packed " + "parameter type"); + return output; + } + + virtual at::Tensor& apply_relu_out( + const at::Tensor& /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/, + at::Tensor& output) { + throw std::runtime_error( + "apply_relu_out is not implemented for this packed " + "parameter type"); + return output; + } + + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32): + // input -> q* -> dq* -> linear* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // X: float32 Tensor, will be quantized to quint8 in the op + // W_prepack: packed qint8 quantized weight and bias + // Returns: + // Y: float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32): + // input -> q* -> dq* -> linear* -> relu* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // input: float32 Tensor, will be quantized to quint8 in the op + // Returns: + // float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + + virtual at::Tensor apply_dynamic( + at::Tensor input, + bool reduce_range = false) = 0; + virtual at::Tensor apply_dynamic_relu( + at::Tensor input, + bool reduce_range = false) = 0; + + virtual at::Tensor& apply_dynamic_out( + const at::Tensor& /* input */, + at::Tensor& output, + bool /* reduce_range */) { + throw std::runtime_error( + "apply_dynamic_out is not implemented for this packed " + "parameter type"); + return output; + } + virtual at::Tensor& apply_dynamic_relu_out( + const at::Tensor& /* input */, + at::Tensor& output, + bool /* reduce_range */) { + throw std::runtime_error( + "apply_dynamic_relu_out is not implemented for this packed " + "parameter type"); + return output; + } + + virtual std::tuple> unpack() = 0; + + virtual std::optional bias() = 0; + + virtual void set_bias(std::optional /*bias*/) { + throw std::runtime_error( + "set_bias is not implemented for this packed " + "parameter type"); + } +}; + +template +struct ConvPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) = 0; + + virtual std::tuple> unpack() = 0; + + virtual torch::List stride() const = 0; + virtual torch::List padding() const = 0; + virtual torch::List 
output_padding() const = 0; + virtual torch::List dilation() const = 0; + virtual int64_t groups() const = 0; + virtual bool transpose() const = 0; +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/ACLUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/ACLUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..42ac2307807bf203c04ae0d95b2a1e49c3073a82 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/ACLUtils.h @@ -0,0 +1,257 @@ +#pragma once + +#include +#if AT_MKLDNN_ACL_ENABLED() + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Utilities for Arm Compute Library (ACL) quantized operations +// Provides interfaces to leverage ACL's accelerated kernels for statically and +// dynamically quantized matmuls (i.e. qlinear and qlinear_dynamic) These are +// utalized through PackedLinearWeightsACL which extends +// PackedLinearWeightsOnednn Note that PackedLinearWeightsACL extends rather +// than replaces PackedLinearWeightsOnednn for AArch64 because ACL currently +// only supports per_tensor weight quantization. +namespace at::native::acl_utils { + +using QuantMatmulCacheKey = std::tuple< + int64_t, // M + bool, // FUSE_RELU + int64_t, // NUM_THREADS + double, // INPUT_SCALE + int64_t, // INPUT_OFFSET + double, // OUTPUT_SCALE + int64_t, // OUTPUT_OFFSET + bool // SIGNED_INPUT + >; + +enum class QuantMatmulCacheKeyIndex { + M, + FUSE_RELU, + NUM_THREADS, + INPUT_SCALE, + INPUT_OFFSET, + OUTPUT_SCALE, + OUTPUT_OFFSET, + SIGNED_INPUT +}; + +// Abstract interface to share common stuff between static/dynamic ACL matmuls. +struct QuantMatmul { + arm_compute::NEGEMMLowpMatrixMultiplyCore gemm; + // key for use in the cache + QuantMatmulCacheKey key; + + QuantMatmul( + int64_t weight_dim_0, + int64_t weight_dim_1, + double weight_scale, + int64_t weight_offset, + int8_t* weight_ptr, + std::optional bias_ptr, + const QuantMatmulCacheKey& cache_key); + + virtual ~QuantMatmul(); + virtual arm_compute::Status validate() = 0; + virtual void configure() = 0; + + protected: + arm_compute::Tensor wei_q_tensor_; + std::optional bia_tensor_; + arm_compute::GEMMInfo gemm_info_; + std::optional relu_info_; +}; + +struct DynamicQuantMatmul : public QuantMatmul { + arm_compute::Tensor src_q_tensor; + arm_compute::Tensor src_tensor; + arm_compute::Tensor dst_tensor; + arm_compute::NEQuantizationLayer quant; + // We need a ReLU layer here (unlike static quantization) because the ReLU + // cannot be "truly" fused with the GEMM through gemm_info in ACL dynamically + // quantized matmuls. 
+ std::optional relu; + + DynamicQuantMatmul( + int64_t weight_dim_0, + int64_t weight_dim_1, + double weight_scale, + int64_t weight_offset, + int8_t* weight_ptr, + std::optional bias_ptr, + const QuantMatmulCacheKey& cache_key); + + ~DynamicQuantMatmul() override; + + arm_compute::Status validate() override; + void configure() override; + + private: + at::Tensor src_q_tensor_orig_; +}; + +struct StaticQuantMatmul : public QuantMatmul { + arm_compute::Tensor src_q_tensor; + arm_compute::Tensor dst_q_tensor; + + StaticQuantMatmul( + int64_t weight_dim_0, + int64_t weight_dim_1, + double weight_scale, + int64_t weight_offset, + int8_t* weight_ptr, + std::optional bias_ptr, + const QuantMatmulCacheKey& cache_key); + + ~StaticQuantMatmul() override; + + arm_compute::Status validate() override; + void configure() override; + + private: + std::optional bia_q_tensor_; + std::optional bia_q_tensor_orig_; +}; + +struct QuantAdd { + arm_compute::Tensor qa_tensor; + arm_compute::Tensor qb_tensor; + arm_compute::Tensor qdst_tensor; + arm_compute::NEArithmeticAddition q_add; + + QuantAdd( + arm_compute::DataType dtype, + const std::vector& input_dims, + double qa_scale, + int64_t qa_offset, + double qb_scale, + int64_t qb_offset, + double dst_scale, + int64_t dst_offset); + + arm_compute::Status validate(); + void configure(); + + private: + arm_compute::ConvertPolicy policy{arm_compute::ConvertPolicy::SATURATE}; +}; + +} // namespace at::native::acl_utils +struct PackedLinearWeightsACL : public PackedLinearWeightsOnednn { + using ACLQuantMatmul = at::native::acl_utils::QuantMatmul; + using ACLDynamicQuantMatmul = at::native::acl_utils::DynamicQuantMatmul; + using ACLStaticQuantMatmul = at::native::acl_utils::StaticQuantMatmul; + using ACLQuantMatmulCacheKey = at::native::acl_utils::QuantMatmulCacheKey; + using ACLQuantMatmulCacheKeyIndex = + at::native::acl_utils::QuantMatmulCacheKeyIndex; + + PackedLinearWeightsACL( + std::unique_ptr weight, + std::optional bias, + at::Tensor orig_weight, + std::optional orig_bias); + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) + override; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + template + std::shared_ptr get_acl_quant_matmul( + const ACLQuantMatmulCacheKey& key) { + return std::dynamic_pointer_cast( + fetch_or_create_acl_quant_matmul(key)); + } + + private: + int64_t k_; + int64_t n_; + int64_t weight_zero_point_; + double weight_scale_; + + // A 2 element (per layer) cache. Given it's not intended to store more than 2 + // elements, we do not need a fancy implementation. The idea behind it is to + // allow for a (configuration free) fast path for autoregressive + // transformer-like models which usually involve 2 input tensor shapes; one + // for the prefill phase and another for the autoregressive phase + std::array, 2> cache_; + + template + std::shared_ptr fetch_or_create_acl_quant_matmul( + const ACLQuantMatmulCacheKey& key) { + // We're only maintaining a 2 element LRU cache + // hit first + if (cache_[0] != nullptr && cache_[0]->key == key) { + return cache_[0]; + } + // hit second + if (cache_[1] != nullptr && cache_[1]->key == key) { + // Update LRU + std::swap(cache_[0], cache_[1]); + return cache_[0]; + } + // miss -> replace Least Recently Used - i.e. 
element at index 1 + cache_[1] = create_acl_quant_matmul(key); + std::swap(cache_[0], cache_[1]); + return cache_[0]; + } + + template + std::shared_ptr create_acl_quant_matmul( + const ACLQuantMatmulCacheKey& key) { + std::optional bias_ptr; + if (bias_.has_value()) { + bias_ptr = (float*)bias_.value().get_data_handle(); + } + auto acl_gemm = std::make_shared( + k_, + n_, + weight_scale_, + weight_zero_point_, + (int8_t*)weight_.get()->get_data_handle(), + bias_ptr, + key); + + // validate + auto status = acl_gemm->validate(); + if (status.error_code() != arm_compute::ErrorCode::OK) { + TORCH_WARN( + "Arm Compute Library's Quantized Matmul Validation Failed: " + + status.error_description()); + return nullptr; + } + + // configure + acl_gemm->configure(); + return acl_gemm; + } + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false); + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point); +}; + +#endif // AT_MKLDNN_ACL_ENABLED() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..bfb907b468c74eadfe26cbee0cbd789b5657a315 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h @@ -0,0 +1,6 @@ +#include + +namespace at::native { +TORCH_API Tensor +quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point); +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h new file mode 100644 index 0000000000000000000000000000000000000000..f8a2b1707ec7855cc619e0e9f48f53adc4e4a942 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor embeddingbag_byte( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) = 0; + + virtual at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) = 0; + + virtual at::Tensor unpack() = 0; + + virtual int64_t bit_rate() const = 0; + virtual int64_t version() const = 0; +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..abe597d89465b34a4100d69e8f9cb90ab0c9aeff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h @@ -0,0 +1,463 @@ +#pragma once + +#include +#if AT_MKLDNN_ENABLED() +#include +#include +#include +#if !defined(__powerpc__) +#include +#endif + +#include + +using PrimitiveCacheKey = std::tuple< + double, // input_scale + int64_t, // input_zero_point + std::vector, // input_shape + double, // output_scale + int64_t, // output_zero_point + int64_t, // 
OMP_number_of_threads + double, // accum_scale + int64_t>; // accum_zero_point + +enum CacheKeyIndex { + InputScale, + InputZeroPoint, + InputShape, + OutputScale, + OutputZeroPoint, + NumOfThreads, +}; + +// Base class of primitive cache +struct PrimitiveCache { + PrimitiveCacheKey key; + + bool hit(const PrimitiveCacheKey& key) { + return this->key == key; + } +}; + +using LinearParams = ideep::matmul_forward_params; +using Conv = dnnl::convolution_forward; +using ConvDesc = dnnl::convolution_forward::primitive_desc; +using ConvParams = ideep::convolution_forward_params; +using Deconv = dnnl::deconvolution_forward; +using DeconvDesc = dnnl::deconvolution_forward::primitive_desc; +using DeconvParams = ideep::deconv_forward_params; + +struct LinearPrimitiveCache : PrimitiveCache { + LinearPrimitiveCache() = default; + + LinearPrimitiveCache( + const PrimitiveCacheKey& key, + const LinearParams& param) { + this->key = key; + this->param = param; + } + + LinearParams param; + + // For dynamic qlinear, scale and zero point + // are set at execution time. So we only need to compare + // the rest part of key. + bool hit_dynamic(const PrimitiveCacheKey& new_key) { + auto const& cached_input_shape = std::get(this->key); + auto const& new_input_shape = std::get(new_key); + return ( + cached_input_shape == new_input_shape && + std::get(this->key) == std::get(new_key)); + } + + LinearParams& get_param() { + return param; + } +}; + +struct ConvPrimitiveCache : PrimitiveCache { + ConvPrimitiveCache() = default; + + ConvPrimitiveCache( + const PrimitiveCacheKey& key, + const ConvParams& params) { + this->key = key; + this->params = params; + } + + ConvParams params; + + ConvParams& get_params() { + return params; + } +}; + +struct DeconvPrimitiveCache : PrimitiveCache { + DeconvPrimitiveCache() = default; + + DeconvPrimitiveCache( + const PrimitiveCacheKey& key, + const DeconvParams& params) { + this->key = key; + this->params = params; + } + + DeconvParams params; + + DeconvParams& get_params() { + return params; + } +}; + +enum PostOps { + NoPostOp, + Relu, + LeakyRelu, + Tanh, + Gelu +}; + + +struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { + PackedLinearWeightsOnednn( + std::unique_ptr weight, + std::optional bias, + at::Tensor orig_weight, + std::optional orig_bias) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)) { + cache_initialized_flag = std::make_unique(); + } + std::unique_ptr weight_; + std::optional bias_; + at::Tensor orig_weight_; + std::optional orig_bias_; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + at::Tensor apply_leaky_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + double negative_slope); + + at::Tensor apply_tanh( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + std::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + private: + LinearPrimitiveCache prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + at::Tensor input, + double 
output_scale, + int64_t output_zero_point, + torch::List post_op_args = torch::List()); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false); + + LinearPrimitiveCache& get_cache() { + return prim_cache; + } +}; + +template +struct PackedConvWeightsOnednn : public ConvPackedParamsBase { + PackedConvWeightsOnednn( + std::unique_ptr weight, + std::optional bias, + at::Tensor orig_weight, + std::optional orig_bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose) { + cache_initialized_flag = std::make_unique(); + } + + std::unique_ptr weight_; + std::optional bias_; + at::Tensor orig_weight_; + std::optional orig_bias_; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + at::Tensor apply_add( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + at::Tensor apply_add_relu( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + ConvPrimitiveCache conv_prim_cache; + DeconvPrimitiveCache deconv_prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + const at::Tensor& input, + const std::optional& accum, + double output_scale, + int64_t output_zero_point); + + ConvPrimitiveCache& get_conv_cache() { + assert(!transpose()); + return conv_prim_cache; + } + + DeconvPrimitiveCache& get_deconv_cache() { + assert(transpose()); + return deconv_prim_cache; + } +}; + +namespace onednn_utils { + +inline ideep::attr_t create_attr_by_post_op( + const std::string_view& binary_post_op, + double binary_alpha, + double input1_scale, + int64_t input1_zero_point, + const ideep::tensor::desc& input1_desc, + const std::string_view& unary_post_op, + const torch::List>& unary_post_op_args, + const std::string_view& unary_post_op_algorithm) { + using ideep::tensor; + if (binary_post_op == "none") { + if (unary_post_op == "relu") { + return ideep::attr_t::fuse_relu(); + } else if (unary_post_op == "leaky_relu") { + TORCH_CHECK( + unary_post_op_args.size() == 1, + "onednn 
qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args"); + auto alpha = unary_post_op_args[0].value().to(); + return ideep::attr_t::fuse_relu_v2(alpha); + } else if (unary_post_op == "tanh") { + return ideep::attr_t::fuse_tanh(); + } else if (unary_post_op == "gelu") { + TORCH_CHECK( + unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh", + "onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm); + auto post_algorithm = unary_post_op_algorithm == "none" ? + dnnl::algorithm::eltwise_gelu_erf : + dnnl::algorithm::eltwise_gelu_tanh; + return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm); + } else if (unary_post_op == "hardtanh") { + TORCH_CHECK( + unary_post_op_args.size() == 2 && + unary_post_op_args[0].has_value() && + unary_post_op_args[1].has_value(), + "hardtanh is expected to have two scalar input: min_val and max_val"); + auto lower_bound_value = + unary_post_op_args[0].value().to(); + auto upper_bound_value = + unary_post_op_args[1].value().to(); + return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value); + } else if (unary_post_op == "hardswish") { + return ideep::attr_t::fuse_hardswish(); + } else if (unary_post_op == "swish") { + return ideep::attr_t::fuse_swish(); + } else { + TORCH_CHECK( + unary_post_op == "none", + "onednn qlinear: unsupported unary post op ", unary_post_op); + } + } else if (binary_post_op == "sum") { + if (unary_post_op == "none") { + return ideep::attr_t::fuse_sum(input1_scale, input1_zero_point); + } else if (unary_post_op == "relu") { + return ideep::attr_t::residual_with_sum_zero_point(input1_scale, input1_zero_point); + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum"); + } + } else if (binary_post_op == "add") { + if (unary_post_op == "none") { + return ideep::attr_t::fuse_binary(ideep::algorithm::binary_add, input1_desc); + } else if (unary_post_op == "relu") { + ideep::post_ops po; + po.append_binary(ideep::algorithm::binary_add, input1_desc); + po.append_eltwise(ideep::algorithm::eltwise_relu, 0, 0); + return ideep::attr_t::attr_post_ops(po); + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add"); + } + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported binary post op ", binary_post_op); + } + return ideep::attr_t(); +} + +// ONEDNN requires symmetric quantization of weight +// Use this util function to check. +inline bool is_weight_symmetric_quant( + const at::Tensor& weight, + bool is_transposed_conv) { + bool is_symmetric = true; + const auto qtype = weight.qscheme(); + if (qtype == c10::kPerTensorAffine) { + is_symmetric &= (weight.q_zero_point() == 0); + } else if (qtype == c10::kPerChannelAffine) { + if (is_transposed_conv) { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } else { + auto output_channels = weight.size(0); + for (int i = 0; i < output_channels; ++i) { + auto zp = weight.q_per_channel_zero_points()[i].item(); + is_symmetric &= (zp == 0); + } + } + } else { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } + return is_symmetric; +} + +// When qengine is x86, use this util func to check if onednn kernel +// is preferred than fbgemm's to get better performance. 
+inline bool should_use_onednn_quant( + const at::Tensor& weight, + bool is_transposed_conv, + int groups, + torch::List output_padding) { + // Performance of onednn is only validated on Linux right now. + // Also, the heuristics for dispatching are based on perf data on Linux. + // So, for x86 qengine, we always use fbgemm kernels if OS is not Linux. + // TODO Support more OSs. +#if !defined(__linux__) + return false; +#else +#if defined(__powerpc__) + constexpr auto vnni_available = true; +#else + const auto vnni_available = cpuinfo_has_x86_avx512vnni(); +#endif + bool w_sym_quant = + is_weight_symmetric_quant(weight, is_transposed_conv); + bool opad_all_zero = + std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }); + return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero; +#endif +} + +} // onednn_utils + +at::Tensor _qconv_prepack_onednn( + at::Tensor weight, // from CPU backend instead of QuantizedCPU + at::Tensor weight_scales, // Weight zero points must be 0 for onednn + double input_scale, + int64_t input_zero_point, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + std::optional> input_shape=std::nullopt); + +#endif // #if AT_MKLDNN_ENABLED() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..dff7ead94d4af235f30af30fff1db3b291394b7e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h @@ -0,0 +1,508 @@ +#pragma once + +#ifdef USE_PYTORCH_QNNPACK +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +inline int kPaddingChannels = 8; +struct QnnpackOperatorDeleter { + void operator()(pytorch_qnnp_operator_t op) { + pytorch_qnnp_delete_operator(op); + } +}; + +// PackedWeight struct for QNNPACK stores the original Weight and Bias as +// QNNPACK currently does not support an unpack function. +// For PyTorch Mobile, once the model is scripted and serialized we don't need +// to call unpack, so we can save some memory by checking for this case and free +// the original weights after packing. +// Input scale is set to null in pre-pack step. QNNPACK needs bias quantized +// with input scale which is available at runtime in pytorch. During runtime if +// input scale value changes then we requantize bias with the updated scale. For +// inference we expect the graph to be static so the input scale should not +// change across consecutive inference calls. 
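// ---------------------------------------------------------------------------
// Editor's note (not part of the upstream header): the run-time requantization
// mentioned above hinges on factors of the form
//   requantization_scale[i] = weight_scale[i] * input_scale / output_scale,
// the same expression used by generate_requantization_scales() further down in
// this header. A simplified stand-alone sketch with hypothetical names:
#include <cstddef>
#include <vector>

inline std::vector<float> compute_requant_scales(
    const std::vector<float>& weight_scales, // one scale per output channel
    float input_scale,
    float output_scale) {
  std::vector<float> scales(weight_scales.size());
  for (std::size_t i = 0; i < weight_scales.size(); ++i) {
    scales[i] = weight_scales[i] * input_scale / output_scale;
  }
  return scales;
}
// Because input_scale is only observed at run time, these factors (and the
// quantized bias that depends on input_scale) must be recomputed whenever the
// observed input scale changes, as the comment above describes.
// ---------------------------------------------------------------------------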
+struct PackedLinearWeightsQnnp : public LinearPackedParamsBase { + PackedLinearWeightsQnnp( + std::unique_ptr w, + at::Tensor orig_weight, + at::Tensor bias, + std::optional input_scale, + at::Tensor w_scales, + std::vector&& w_zps) + : w(std::move(w)), + orig_weight(std::move(orig_weight)), + bias_(at::native::mobile::allocate_padded_contiguous_if_needed( + bias, bias.suggest_memory_format())), + per_channel_(this->orig_weight.qscheme() == at::kPerChannelAffine), + input_scale(std::move(input_scale)), + w_scales(std::move(w_scales)), + w_zero_points(std::move(w_zps)), + q_scheme(this->orig_weight.qscheme()) { + weight_sizes = this->orig_weight.sizes().vec(); + } + + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias_; + bool per_channel_; + std::optional input_scale; + at::Tensor w_scales; + std::vector w_zero_points; + std::vector requantization_scales; + std::vector weight_sizes; + c10::QScheme q_scheme; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + bool per_channel() const { + return per_channel_; + } + + private: + std::mutex qnnp_mutex_; + +#ifdef USE_XNNPACK + xnnpack_operator xnnp_linear_op; + + template + at::Tensor apply_impl_xnnp( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +#endif // USE_XNNPACK + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range); +}; + +template +struct PackedConvWeightsQnnp : public ConvPackedParamsBase { + PackedConvWeightsQnnp( + std::unique_ptr w, + at::Tensor orig_weight, + at::Tensor bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose, + std::optional input_scale, + std::vector kernel, + at::Tensor w_scale, + std::vector&& w_zps, + bool is_per_channel) + : w(std::move(w)), + orig_weight(std::move(orig_weight)), + bias(std::move(bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose), + is_per_channel_(is_per_channel), + input_scale(input_scale), + kernel_(std::move(kernel)), + w_scales(std::move(w_scale)), + w_zero_points(std::move(w_zps)) { + const bool any_padding = std::any_of( + padding_.begin(), padding_.end(), [](const auto& e) { return e != 0; }); + const size_t kernel_size = + std::accumulate(kernel_.begin(), kernel_.end(), 1, std::multiplies<>()); + + const size_t group_input_channels = transpose + ? this->orig_weight.size(0) / groups + : this->orig_weight.size(1); + const size_t group_output_channels = transpose + ? this->orig_weight.size(1) + : this->orig_weight.size(0) / groups; + + const size_t kernel_depth = kSpatialDim == 3 ? 
kernel_[0] : 1; + const size_t kernel_height = kernel_[kSpatialDim - 2]; + const size_t kernel_width = kernel_[kSpatialDim - 1]; + + pytorch_qnnp_ukernel_type ukernel_type; + if (transpose_) { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_none; + + const bool has_depthwise_dimensions = + (kSpatialDim == 2 && + ((kernel_height == 3 && kernel_width == 3) || + (kernel_height == 5 && kernel_width == 5))) || + (kSpatialDim == 3 && kernel_height == 3 && kernel_width == 3 && + kernel_depth == 3); + const bool has_depthwise_grouping = + group_input_channels == 1 && group_output_channels == 1 && groups > 1; + + if (has_depthwise_dimensions && has_depthwise_grouping) { + ukernel_type = pytorch_qnnp_ukernel_type_dwconv; + } else if ( + kernel_size == 1 && + std::all_of( + stride_.begin(), + stride_.end(), + [](const auto& e) { return e == 1; }) && + !any_padding) { + ukernel_type = group_input_channels >= SIZE_MAX + ? pytorch_qnnp_ukernel_type_xzp_gemm + : pytorch_qnnp_ukernel_type_gemm; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } + } + + if (is_per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) { + TORCH_INTERNAL_ASSERT( + false, "Per channel quantized weights are not supported for XZP kernels"); + } + + pytorch_qnnp_operator_t convolution{nullptr}; + // Initially all the params are set to zero. + convolution = static_cast( + calloc(1, sizeof(struct pytorch_qnnp_operator))); + if (convolution == nullptr) { + TORCH_INTERNAL_ASSERT( + false, "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + } + + convolution_op = + std::unique_ptr( + convolution); + + // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) + convolution->ukernel_type = ukernel_type; + convolution->groups = groups; + convolution->group_input_channels = group_input_channels; + convolution->group_output_channels = group_output_channels; + convolution->kernel_depth = kernel_depth; + convolution->kernel_height = kernel_height; + convolution->kernel_width = kernel_width; + convolution->stride_depth = kSpatialDim == 3 ? stride_[0] : 1; + convolution->stride_height = stride_[kSpatialDim - 2]; + convolution->stride_width = stride_[kSpatialDim - 1]; + convolution->dilation_depth = kSpatialDim == 3 ? dilation_[0] : 1; + convolution->dilation_height = dilation_[kSpatialDim - 2]; + convolution->dilation_width = dilation_[kSpatialDim - 1]; + convolution->input_padding_height = padding_[kSpatialDim - 2]; + convolution->input_padding_width = padding_[kSpatialDim - 1]; + convolution->input_padding_depth = kSpatialDim == 3 ? 
padding_[0] : 0; + convolution->per_channel = is_per_channel_; + convolution->transpose = transpose_; + + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + + size_t zero_size = sizeof(uint8_t) * k_stride; + size_t zero_offset = 0; + + if (transpose_) { + convolution->adjustment_width = output_padding_[1]; + convolution->adjustment_height = output_padding_[0]; + if (group_input_channels < 8) { + zero_size += 8; + zero_offset = 8; + } + } else { + zero_buffer_size = 0; + if (any_padding) { + zero_size = 0; + zero_offset = 0; + if (ukernel_type == pytorch_qnnp_ukernel_type_dwconv) { + const uint32_t cr = pytorch_qnnp_params.q8dw9.cr; + const size_t group_stride = (groups + (cr - 1)) & -cr; + if (groups >= 8) { + zero_size = sizeof(uint8_t) * group_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * group_stride + 8; + zero_offset = sizeof(uint8_t) * 8; + } + } else if ( + ukernel_type == pytorch_qnnp_ukernel_type_conv || + ukernel_type == pytorch_qnnp_ukernel_type_gemm) { + if (group_input_channels >= 8) { + zero_size = sizeof(uint8_t) * k_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * k_stride + 8; + zero_offset = 8; + } + } + } + } + + // NOLINTNEXTLINE(clang-analyzer-optin.portability.UnixAPI) + void* zero_buffer = malloc(zero_size); + if (zero_buffer == nullptr) { + pytorch_qnnp_delete_operator(convolution); + TORCH_INTERNAL_ASSERT( + false, "failed to allocate %zu bytes for zero padding", + zero_size); + } + // Need to set to input zero point + // memset(zero_buffer, input_zero_point, zero_size); + zero_buffer_size = zero_size; + convolution->zero_buffer = zero_buffer; + convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset); + } + + std::unique_ptr convolution_op; + #ifdef USE_XNNPACK + xnnpack_operator xnnp_convolution_op; + #endif // USE_XNNPACK + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + bool transpose_; + bool is_per_channel_; + std::optional input_scale; + std::vector kernel_; + at::Tensor w_scales; + std::vector w_zero_points; + std::vector requantization_scales; + size_t zero_buffer_size; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range=false) override; + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return transpose_; + } + + bool per_channel() const { + return is_per_channel_; + } + + private: + std::mutex qnnp_mutex_; + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + +#ifdef USE_XNNPACK + template + at::Tensor 
apply_impl_xnnp( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +#endif // USE_XNNPACK +}; + +enum class Activation : uint8_t { NONE = 0, RELU = 1 }; + +template +inline T QuantizeValue(float scale, int32_t zero_point, float value) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + auto r = zero_point + static_cast(std::nearbyint(value / scale)); + r = std::max(r, qmin); + r = std::min(r, qmax); + return static_cast(r); +} + +template +inline std::pair activationLimits( + float scale, + int32_t zero_point, + Activation Ac) { + switch (Ac) { + case Activation::NONE: + return {std::numeric_limits::min(), + std::numeric_limits::max()}; + case Activation::RELU: + return {QuantizeValue(scale, zero_point, 0.0), + std::numeric_limits::max()}; + default: +#ifdef _MSC_VER + __assume(0); +#else + __builtin_unreachable(); +#endif + } +} + +namespace at::native::qnnp_avgpool_helper { +Tensor qnnpack_avg_pool2d( + Tensor input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override); +} // namespace at::native::qnnp_avgpool_helper + +namespace { +[[maybe_unused]] std::vector generate_requantization_scales( + const at::Tensor& weight_scales, + const float input_scale, + const float output_scale, + std::vector& requant_scales) { + // Since weight scale is allocated with padding + // weight_scales.numel() gives us padded num elements. + const auto num_output_channels_padded = weight_scales.numel(); + float *const weight_scales_data = weight_scales.data_ptr(); + if (static_cast(requant_scales.size()) < num_output_channels_padded) { + requant_scales.resize(num_output_channels_padded); + } + for (const auto i : c10::irange(num_output_channels_padded)) { + const auto inverse_output_scale = 1.f /output_scale; + requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale; + TORCH_CHECK( + (requant_scales[i] > 0.0f && std::isnormal(requant_scales[i])), + "failed to create op with requantization scale: ", + requant_scales[i], + ": requantization scale must be finite and positive"); + } + return requant_scales; +} + +[[maybe_unused]] std::pair, at::Tensor> +make_zero_points_and_scales_tensor( + const at::Tensor& weight_contig, + bool transpose = false, + uint32_t groups = 1) { + const int out_ch_idx = transpose ? 1 : 0; + const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1); + // Add 8 to account for bufferring needed by QNNPACK. + const auto num_output_channels_padded = num_output_channels + kPaddingChannels; + const auto qtype = weight_contig.qscheme(); + std::vector weight_zp(num_output_channels_padded, 0); + // Adjust weight zero point, similar to weight data. 
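// The +128 below re-centres a signed int8-style zero point into the unsigned
// uint8 range consumed by the QNNPACK kernels; the padded tail entries of
// weight_zp keep their zero-initialised value.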
+ if (qtype == at::kPerTensorAffine) { + for (const auto i : c10::irange(num_output_channels)) { + weight_zp[i] = (uint8_t)(weight_contig.q_zero_point() + 128); + } + } else if (qtype == at::kPerChannelAffine) { + TORCH_CHECK( + weight_contig.q_per_channel_zero_points().scalar_type() == at::kLong, + "Per channel zero points dtype must be long int."); + const int64_t* per_channel_zero_points = + weight_contig.q_per_channel_zero_points().data_ptr(); + for (const auto i : c10::irange(num_output_channels)) { + weight_zp[i] = (uint8_t)(per_channel_zero_points[i] + 128); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme."); + } + at:: Tensor weight_scales = + at::empty( + {num_output_channels_padded}, + at::device(at::kCPU).dtype(at::kFloat)); + float *const weight_scales_data = weight_scales.data_ptr(); + if (qtype == at::kPerTensorAffine) { + for (const auto i : c10::irange(num_output_channels)) { + weight_scales_data[i] = weight_contig.q_scale(); + } + } else if (qtype == at::kPerChannelAffine) { + TORCH_CHECK( + weight_contig.q_per_channel_scales().scalar_type() == at::kDouble, + "Per channel scales dtype must be double."); + const double *const per_channel_scales = + weight_contig.q_per_channel_scales().data_ptr(); + for (const auto i : c10::irange(num_output_channels)) { + weight_scales_data[i] = static_cast(per_channel_scales[i]); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme."); + } + for (const auto i : c10::irange(num_output_channels, num_output_channels_padded)) { + weight_scales_data[i] = 1.f; + } + return {weight_zp, weight_scales}; +} +} // namespace + +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..4b62044e226faba01bfe42dcf7e1385c1cf221ee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h @@ -0,0 +1,240 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace quant_utils { +namespace { + float RawUint16ToFp16(unsigned short value) { + // Convert raw 16 bits half precision floating point number + // to single precision floating point number. + const unsigned short sign_bits = value >> 15; + const unsigned short exponent_bits = value >> 10 & 0x1f; + const unsigned short significand_bits = value & 0x3ff; + + const float sign = sign_bits ? -1 : 1; + const float significand = + 1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10; + const float exponent = exponent_bits - 0xf; + + return sign * std::ldexp(significand, exponent); +} + +template +bool CheckAndSaturate(T max_val, T* element) { + if (*element > max_val) { + *element = max_val; + return true; + } + if (*element < -max_val) { + *element = -max_val; + return true; + } + return false; +} +} +using namespace std; +// A structure to hold quantization parameters 'scale' and 'zero_point'. +// The meaning of these values is as the constants in the quantization equation +// +// real_value = scale * (quantized_value - zero_point) +// +// In other words, 'zero_point' is the quantized value that corresponds +// to the real value 0, and 'scale' is the difference of real values +// corresponding to consecutive quantized values. 
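// A quick worked instance of the mapping above: with scale = 0.1 and
// zero_point = 30, the quantized value 50 represents 0.1 * (50 - 30) = 2.0,
// the real value 0 is represented exactly by quantized_value = 30, and the
// inverse direction gives round(2.0 / 0.1) + 30 = 50.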
+struct TensorQuantizationParams { + double scale; + std::int32_t zero_point; + int precision; +}; + +// Use fp16_min as the small scale cutoff because we don't want to use scales in +// fp16 subnormal range. This is to be consistent with Glow and FakeLowP +// implementation for NNPI. +constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + +// Following implementation should be identical to fbgemm::ChooseQuantizationParams +inline TensorQuantizationParams ChooseQuantizationParams( + float min, + float max, + int32_t qmin, + int32_t qmax, + bool preserve_sparsity = false, + bool force_scale_power_of_two = false, + bool reduce_range = false) { + TORCH_CHECK( + min <= max, + "In ChooseQuantizationParams, min should be less than or equal to max"); + + if (reduce_range) { + qmin = qmin/2; + qmax = qmax/2; + } + if (min < 0 && max > 0 && preserve_sparsity) { + int symmetric_qmin = -((qmax - qmin) / 2 + 1); + int symmetric_qmax = (qmax - qmin) / 2; + double max_scale = + std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax)); + min = max_scale * symmetric_qmin; + max = max_scale * symmetric_qmax; + } + + // We extend the [min, max] interval to ensure that it contains 0. + // Otherwise, we would not meet the requirement that 0 be an exactly + // representable value. + min = std::min(min, 0.f); + max = std::max(max, 0.f); + + TORCH_CHECK( + qmin < qmax, + "In ChooseQuantizationParams, qmin should be less than qmax"); + + // Use double precision for intermediate computation but use single precision + // in final number to reflect the actual number used during quantization. + double scale = (static_cast(max) - min) / (qmax - qmin); + // If scale is 0 or too small so its reciprocal is infinity, we arbitrary + // adjust the scale to 0.1 . We want to avoid scale's reciprocal being + // infinity because some of fbgemm code pre-computes scale's reciprocal to do + // multiplication instead of division in the time critical part of code. + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + TORCH_CHECK(scale > 0, "quantization scale should be > 0"); + + if (force_scale_power_of_two) { + if (scale < 1) { + scale = 1.0 / (1 << static_cast(floor(log(1.0 / scale) / log(2)))); + } else { + scale = 1 << static_cast(ceil(log(scale) / log(2))); + } + } + + // Cut off small scale + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust the min and max based on the new scale + if (min == 0.0f) { + max = SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else if (max == 0.0f) { + min = -SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min *= amplifier; + max *= amplifier; + } + } + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. 
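+ // Rough worked example: for min = -1.0f, max = 1.0f, qmin = 0, qmax = 255, the scale computed above is 2 / 255 (about 0.00784), both candidate zero points below evaluate to 127.5, and the nudging step rounds that to 128.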
+ double zero_point_from_min = qmin - min / static_cast(scale); + double zero_point_from_max = qmax - max / static_cast(scale); + double zero_point_from_min_error = + std::abs(qmin) - std::abs(min / static_cast(scale)); + double zero_point_from_max_error = + std::abs(qmax) - std::abs(max / static_cast(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // for symmetric quantization (preserve_sparsity == true), we force zero_point + // to be a middle value between qmin and qmax. + // If either min or max is 0, then we just use 0 as zero_point. + if (min < 0 && max > 0 && preserve_sparsity) { + initial_zero_point = static_cast(qmin + qmax) / 2; + } + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with zero + // padding). + int32_t nudged_zero_point = 0; + if (initial_zero_point < qmin) { + nudged_zero_point = qmin; + } else if (initial_zero_point > qmax) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = nearbyint(initial_zero_point); + } + + TensorQuantizationParams result; + result.scale = scale; + result.zero_point = nudged_zero_point; + return result; +} + +// This function helps to convert the Conv1D dimensions usable by the Conv2d op. +constexpr int64_t kConv1dSqueezeDim = 0; +[[maybe_unused]] static torch::List MakeArgForConv1d( + const torch::List& arg, + int64_t base_value) { + TORCH_CHECK(!arg.empty(), "Argument must have elements."); + torch::List result({arg.get(0), base_value}); + if (arg.size() == 1) { + result[1] = arg.get(0); + } else { + result[1] = arg.get(1); + } + result[kConv1dSqueezeDim] = base_value; + return result; +} + +// The range for using FP16 quantization of weights requires that the elements +// should be in the range of [5.96e-8, 65504]. If it is out of range, then the +// number will be saturated to max or min representable values by FP16. +inline void HandleWeightsSaturation(int64_t N, float* weight) { + const float kFp16Max = RawUint16ToFp16(0x7BFF); + bool found_out_of_range = false; + for (const auto i : c10::irange(N)) { + bool saturate = CheckAndSaturate(kFp16Max, weight + i); + if (saturate) { + found_out_of_range = true; + } + } + if (found_out_of_range) { + TORCH_WARN("FOUND weight out of range "); + } +} + +// Util function for quantizing bias. 
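+// The bias is quantized to int32 with scale = input_scale * weight_scale and zero point 0, so it can be added directly to the int32 accumulator of the input-times-weight products in the quantized kernels.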
+inline at::Tensor QuantizeBias( + bool is_per_channel, + const at::Tensor& bias, + const at::Tensor& weight_contig, + double input_scale) { + at::Tensor qbias; + if (is_per_channel) { + auto bias_quant_scales = + weight_contig.q_per_channel_scales() * input_scale; + auto bias_zp = at::zeros(bias_quant_scales.sizes(), c10::kInt); + qbias = at::native::quantize_per_channel( + bias, bias_quant_scales, bias_zp, 0, c10::kQInt32); + } else { + qbias = at::native::quantize_per_tensor( + bias, weight_contig.q_scale() * input_scale, 0, c10::kQInt32); + } + return qbias; +} + +} // namespace quant_utils diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h new file mode 100644 index 0000000000000000000000000000000000000000..265b4cb89fb8588cafa1e05bad5cfbfd786675cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h @@ -0,0 +1,282 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace at::native { + +using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, + const Scalar& /*negval_*/); +using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */); +using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point); +using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qclamp_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& min, + const Scalar& max, + at::Tensor& /*qy*/); +using qclamp_minmax_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& /*min or max*/, + at::Tensor& /*qy*/); +using qthreshold_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& threshold, + const Scalar& value, + at::Tensor& /*qy*/); +using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qelu_fn = void(*)( + const at::Tensor& /*qx*/, + const Scalar& /*alpha*/, + const Scalar& /*scale*/, + const Scalar& /*input_scale*/, + at::Tensor& /*qy*/); +using qbinary_fn = + void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/); +using qadd_scalar_fn = + void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Scalar& other /*other*/); +using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qdropout_fn = void(*)( + const at::Tensor& /*qx*/, + const Scalar& /*p*/, + bool training /*training*/, + at::Tensor& /*qy*/); +using qmaxpool_2d_fn = void (*)( + const Tensor& qx, + int64_t iC, // input/output channels + int64_t iH, + int64_t iW, // input sizes + int64_t oH, + int64_t oW, // output sizes + int64_t kH, + int64_t kW, // kernel size + int64_t sH, + int64_t sW, // strides + int64_t pH, + int64_t pW, // padding + int64_t dH, + int64_t dW, // dilation + Tensor& qy); +using qmaxpool_3d_fn = void (*)( + const Tensor& qx, + int64_t iC, // input/output channels + int64_t iT, + int64_t iH, + int64_t iW, // input sizes + int64_t oT, + int64_t oH, + int64_t oW, // output sizes + int64_t kT, + int64_t kH, + int64_t kW, // kernel size + int64_t sT, + int64_t sH, + int64_t sW, // strides + int64_t pT, + int64_t pH, + int64_t pW, // padding + int64_t dT, + int64_t dH, + int64_t dW, // dilation + Tensor& qy); +using qadaptive_avg_pool2d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t sizeB, + int64_t sizeC, + int64_t isizeH, + 
int64_t isizeW, + int64_t osizeH, + int64_t osizeW, + int64_t istrideB, + int64_t istrideC, + int64_t istrideH, + int64_t istrideW); +using qadaptive_avg_pool3d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t sizeB, + int64_t sizeC, + int64_t isizeD, + int64_t isizeH, + int64_t isizeW, + int64_t osizeD, + int64_t osizeH, + int64_t osizeW, + int64_t istrideB, + int64_t istrideC, + int64_t istrideD, + int64_t istrideH, + int64_t istrideW); +using qavg_pool2d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t nBatch, + int64_t nInputPlane, + int64_t inputWidth, + int64_t inputHeight, + int64_t outputWidth, + int64_t outputHeight, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + bool count_include_pad, + std::optional divisor_override); + +using qavg_pool3d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t nBatch, + int64_t nInputPlane, + int64_t inputWidth, + int64_t inputHeight, + int64_t inputDepth, + int64_t outputWidth, + int64_t outputHeight, + int64_t outputDepth, + int kW, + int kH, + int kD, + int dW, + int dH, + int dD, + int padW, + int padH, + int padD, + bool count_include_pad, + std::optional divisor_override); + +using qupsample_bilinear2d_fn = void (*)( + Tensor& output, + const Tensor& input, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners, + std::optional scales_h, + std::optional scales_w); + +using qcat_nhwc_fn = Tensor (*)( + const MaterializedITensorListRef& qxs, + int64_t dim, + double scale, + int64_t zero_point); +using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool); + +using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, int64_t, int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&); + +using qnormalize_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + bool /* affine_per_channel */, + int /* num_channels */, + int /* num_groups */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */); + +using qmean_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* opt_dim */, + bool /* keepdim */, + std::optional /* opt_dtype */, + Tensor& /* Y */); + +using qstd_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* dim */, + const std::optional& /* correction */, + bool /* keepdim */, + Tensor& /* Y */); + +using qnormalize_nhwc_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + bool /* affine_per_channel */, + int /* num_channels */, + int /* num_groups */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */); + +using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, + const Tensor& /*qw*/); + +using qbinary_eltwise_cpu_fn = void (*)( + Tensor& /*out*/, + const Tensor& /*qx*/, + double /*qx_scale*/, + int64_t /*qx_zero_point*/, + const Tensor& /*qy*/, + double /*qy_scale*/, + int64_t /*qy_zero_point*/, + double /*output_scale*/, + int64_t /*output_zero_point*/); + +using qbatch_norm_cpu_fn = void(*)( + int64_t /*N*/, + int64_t /*C*/, + int64_t /*H * W*/, + int64_t /*in_zero_point*/, + int64_t /*out_zero_point*/, + const Tensor& /*input*/, + const Tensor& /*a*/, + const Tensor& /*b*/, + Tensor& /*output*/); + +DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub) +DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub) 
+DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub) +DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub) +DECLARE_DISPATCH(qavg_pool2d_fn, qavg_pool2d_nhwc_stub) +DECLARE_DISPATCH(qavg_pool3d_fn, qavg_pool3d_nhwc_stub) +DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub) +DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub) +DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub) +DECLARE_DISPATCH(qbinary_fn, qadd_stub) +DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub) +DECLARE_DISPATCH(qbinary_fn, qmul_stub) +DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub) +DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub) +DECLARE_DISPATCH(qclamp_fn, qclamp_stub) +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub) +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub) +DECLARE_DISPATCH(qelu_fn, qelu_stub) +DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub) +DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub) +DECLARE_DISPATCH(qdropout_fn, qdropout_stub) +DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub) +DECLARE_DISPATCH(qmaxpool_3d_fn, qmaxpool_3d_nthwc_stub) +DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub) +DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub) +DECLARE_DISPATCH(qrelu_fn, qrelu_stub) +DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub) +DECLARE_DISPATCH(qgelu_fn, qgelu_stub) +DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub) +DECLARE_DISPATCH(qtanh_fn, qtanh_stub) +DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub) +DECLARE_DISPATCH(qtopk_fn, qtopk_stub) +DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub) +DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub) +DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub) +DECLARE_DISPATCH(qprelu_fn, qprelu_stub) +DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qmul_tensor_cpu_stub) +DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_tensor_cpu_stub) +DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_relu_tensor_cpu_stub) +DECLARE_DISPATCH(qbatch_norm_cpu_fn, qbatch_norm_cpu_stub) + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..5072862553b8b1fca8f6e265a7ae75e2d8df504b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h @@ -0,0 +1,17 @@ +#pragma once + +#ifdef USE_RUY_QMATMUL + +#include + +namespace at::native::ruy_utils { + +ruy::Context* get_ruy_context(); + +void quantize_multiplier(double scale, + int* multiplier_fixedpoint, + int* multiplier_exponent); + +} // namespace at::native::ruy_utils + +#endif // USE_RUY_QMATMUL diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..e45f75cafe237b914582918f89a7f83b728abbed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h @@ -0,0 +1,331 @@ +#pragma once + +#ifdef USE_XNNPACK +#include + +#include +#include + +using xnnpack_operator = at::native::xnnpack::Operator; + +namespace at::native::xnnp_utils { + +/* + * Return shape in the same order as the memory format + * e.g. channels_last will return NHWC instead of NCHW + */ +std::vector get_mem_format_aware_shape(const at::Tensor& in); + +/* + * Input is always int8_t, output can be [int8_t, uint8_t]. 
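+ * As a concrete case of the rule below, an int8 value of -5 copied to a uint8_t output (offset 128) is stored as 123.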
+ * input + offset = output + * int8_t + 128 = uint8_t + * int8_t + 0 = int8_t + */ +template +void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out); + +template +Tensor convert_conv_weights_to_channel_last_tensor( + const at::Tensor& src, + int groups, + bool transpose); + +/* + * Series of create wrapper functions to call xnn_create_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_convolution2d_nhwc( + uint32_t pad_top, + uint32_t pad_right, + uint32_t pad_bottom, + uint32_t pad_left, + uint32_t kernel_h, + uint32_t kernel_w, + uint32_t stride_h, + uint32_t stride_w, + uint32_t dilation_h, + uint32_t dilation_w, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t ip_chan_stride, + size_t op_chan_stride, + int8_t izp, + float ip_scale, + int8_t kzp, + const float* k_scales, + const int8_t* kernel, + const int32_t* bias, + int8_t ozp, + float op_scale, + int8_t op_min, + int8_t op_max, + uint32_t flags, + xnn_operator_t* op, + bool per_channel, + bool transpose) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero." + "But got: ", kzp); + + if (transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_create_deconvolution2d_nhwc_qs8( + pad_top, /* uint32_t output_padding_top */ + pad_right, /* uint32_t output_padding_right */ + pad_bottom, /* uint32_t output_padding_bottom */ + pad_left, /* uint32_t output_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t stride_height */ + stride_w, /* uint32_t stride_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels */ + ip_chan_stride, /* size_t input_pixel_stride */ + op_chan_stride, /* size_t output_pixel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* deconvolution_op_out */ + + } + + if (!per_channel) { + return xnn_create_convolution2d_nhwc_qs8( + pad_top, /* uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t 
output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } else { /* per_channel */ + return xnn_create_convolution2d_nhwc_qs8_qc8w( + pad_top, /* uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales, /* const float* kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } +} + +/* + * Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_convolution2d_nhwc( + xnn_operator_t op, + size_t batch, + size_t in_h, + size_t in_w, + pthreadpool_t pt_pool, + bool per_channel = false, + bool transpose = false, + uint32_t adj_h = 0, + uint32_t adj_w = 0) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_reshape_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + adj_h, /* uint32_t adjustment_height */ + adj_w, /* uint32_t adjustment_width */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } + + size_t workspace_size = SIZE_MAX; + size_t workspace_alignment = SIZE_MAX; + + if (!per_channel) { + return xnn_reshape_convolution2d_nhwc_qs8( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } else { /* per_channel */ + return xnn_reshape_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } +} + + +/* + * Series of setup wrapper functions to call xnn_setup_[de]conv* functions. 
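+ * Together with the create and reshape wrappers above, these follow the usual XNNPACK operator lifecycle: create packs weights and quantization parameters once, reshape fixes the input geometry (and any workspace), and setup binds the input/output pointers before the operator is run.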
+ */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_convolution2d_nhwc( + xnn_operator_t op, + const int8_t* inp, + int8_t* outp, + bool per_channel = false, + bool transpose = false) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + + return xnn_setup_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } + + if (!per_channel) { + return xnn_setup_convolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } else { /* per_channel */ + return xnn_setup_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } +} + + +/* + * Series of wrapper functions to call xnn_create* and xnn_setup* + * functions for linear + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_fully_connected_nc( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + int8_t input_zero_point, + float input_scale, + int8_t kernel_zero_point, + float kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_operator_t* fully_connected_op_out) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero." + "But got: ", kernel_zero_point); + return xnn_create_fully_connected_nc_qs8( + input_channels, /* size_t input_channels */ + output_channels, /* size_t output_channels */ + input_stride, /* size_t input_stride */ + output_stride, /* size_t output_stride */ + input_zero_point, /* int8_t input_zero_point */ + input_scale, /* float input_scale */ + kernel_scale, /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + output_zero_point, /* int8_t output_zero_point */ + output_scale, /* float output_scale */ + output_min, /* int8_t output_min */ + output_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t */ + fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_fully_connected_nc( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) { + return xnn_reshape_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + batch_size, /* size_t batch_size */ + threadpool); /* pthreadpool_t threadpool */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_fully_connected_nc( + xnn_operator_t fully_connected_op, + const int8_t* input, + int8_t* output) { + return xnn_setup_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + input, /* const int8_t* input */ + output /* int8_t* output */ + ); +} + +} // namespace at::native::xnnp_utils + +#endif // USE_XNNPACK diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h new file mode 100644 index 0000000000000000000000000000000000000000..40c6d1b79ead23a5164e11320475fbbc30a9f945 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h @@ -0,0 +1,417 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#if !defined(__s390x__) && !defined(__powerpc__) +#include +#endif + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + + +#include + +/* Convolution prepacked parameters serialization. + * + * Version 1 + * + * - Fields: + * 1. weight + * 2. bias + * 3. stride x kSpatialDim + * 4. padding x kSpatialDim + * 5. dilation x kSpatialDim + * 6. groups + * + * Version 2 + * + * - Fields: + * 0. version (string) + * 1. list of non-optional tensors + * 0: packed parameters (int16_t) + * - kSpatialDim + * - stride x kSpatialDim + * - padding x kSpatialDim + * - dilation x kSpatialDim + * - output_padding x kSpatialDim + * - groups + * - transpose (0 or 1) + * 1: weight + * 2. list of optional tensors + * 0: bias + * + * Version 3 + * + * - Fields: + * 0. version (int64_t) + * 1. list of int64_t configuration values + * - kSpatialDim + * - stride x kSpatialDim + * - padding x kSpatialDim + * - dilation x kSpatialDim + * - output_padding x kSpatialDim + * - groups + * - flags (bitmask) + * - (1 << 0) transpose (1 = yes) + * 2. list of optional tensors + * 0: None (helps with type inference) + * 1: weight (this must be present) + * 2: bias + */ + +using ConvParamsSerializationTypeV2 = std::tuple< + // version, for versions 2 and up + std::string, + // non-optional tensors + std::vector, + // optional tensors + std::vector>>; + +using ConvParamsSerializationTypeV3 = std::tuple< + // version, int for versions 3 and up + int64_t, + // configuration values + std::vector, + // optional tensors + std::vector>>; + +// Parses any historical conv packed params format into +// the current format. 
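+// Concretely, for a 2D conv (kSpatialDim = 2) the version 3 config_vals is an 11-element list:
+// [2, stride x 2, padding x 2, dilation x 2, output_padding x 2, groups, flags], where bit 0 of flags marks a transposed conv.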
+template +ConvParamsSerializationTypeV3 parse_conv_serialized_state(const c10::IValue& v) { + + // determine the version based on IValue contents + int version = -1; + if (v.isTuple()) { + const auto& elements = v.toTupleRef().elements(); + if (!elements.empty()) { + auto firstElement = elements[0]; + if (firstElement.isTensor()) { + version = 1; + } else if (firstElement.isString()) { + const std::string& version_str = firstElement.toStringRef(); + // note: not parsing the string to automatically handle bad + // inputs + if (version_str == "2") { + version = 2; + } + } else if (firstElement.isInt()) { + auto raw_version = firstElement.toInt(); + if (raw_version == 3) { + version = 3; + } + } + } + } + TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version"); + + if (version == 1) { + // version 1 - convert to version 3 manually + + const auto& elements = v.toTupleRef().elements(); + + at::Tensor weight = elements[0].toTensor(); + std::optional bias = elements[1].toOptional(); + torch::List stride_x_kSpatialDim = elements[2].toTensorList(); + torch::List padding_x_kSpatialDim = elements[3].toTensorList(); + torch::List dilation_x_kSpatialDim = elements[4].toTensorList(); + at::Tensor groups = elements[5].toTensor(); + + std::vector config_vals; + config_vals.reserve( + stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() + + dilation_x_kSpatialDim.size() + kSpatialDim + 3); + config_vals.push_back(kSpatialDim); + for (const auto i : c10::irange(stride_x_kSpatialDim.size())) { + auto const & stride = stride_x_kSpatialDim.get(i); + config_vals.push_back(stride[0].item()); + } + for (const auto i : c10::irange(padding_x_kSpatialDim.size())) { + auto const &padding = padding_x_kSpatialDim.get(i); + config_vals.push_back(padding[0].item()); + } + for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) { + auto const &dilation = dilation_x_kSpatialDim.get(i); + config_vals.push_back(dilation[0].item()); + } + // output_padding does not exist in v1, so we fill in a default value + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + config_vals.push_back(0); + } + config_vals.push_back(groups[0].item()); + // transpose does not exist in v1, so we fill in a default value + config_vals.push_back(0); + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); + } else if (version == 2) { + // version 2 + const auto& elements = v.toTupleRef().elements(); + std::vector non_optional = elements[1].toTensorList().vec(); + std::vector> optional; + + if (elements[2].isTensorList()) { + for (const auto& elem : elements[2].toTensorList()) { + optional.emplace_back(static_cast(elem)); + } + } else { + for (const auto& elem : elements[2].toList()) { + optional.emplace_back(static_cast(elem).toOptional()); + } + } + // create default optional value for bias + if (optional.empty()) { + optional.emplace_back(); + } + + auto config_a = non_optional[0].accessor(); + std::vector config_vals; + config_vals.reserve(config_a.size(0)); + for (const auto i : c10::irange(config_a.size(0))) { + config_vals.emplace_back(config_a[i]); + } + + auto weight = non_optional[1]; + auto bias = optional[0]; + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); + } else if (version == 3) { + return v.to(); + } else { + 
TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ", + version); + } +} + +#define QCONV_SERIALIZATION_VERSION 2 + +#if QCONV_SERIALIZATION_VERSION == 2 +using ConvParamsSerializationType = ConvParamsSerializationTypeV2; + +template +ConvParamsSerializationTypeV2 serialize_conv( + const c10::intrusive_ptr>& params) { + + std::string version = "2"; + std::vector non_optional; + std::vector> optional; + + // create a packed int8_t tensor for conv params + std::vector params_vec; + params_vec.push_back(kSpatialDim); + auto stride = params->stride().vec(); + params_vec.insert(params_vec.end(), stride.begin(), stride.end()); + auto padding = params->padding().vec(); + params_vec.insert(params_vec.end(), padding.begin(), padding.end()); + auto dilation = params->dilation().vec(); + params_vec.insert(params_vec.end(), dilation.begin(), dilation.end()); + auto output_padding = params->output_padding().vec(); + params_vec.insert(params_vec.end(), output_padding.begin(), + output_padding.end()); + params_vec.push_back(params->groups()); + params_vec.push_back(params->transpose()); + int64_t vec_size = params_vec.size(); + at::Tensor params_tensor = at::from_blob( + params_vec.data(), {vec_size}, + at::TensorOptions().dtype(at::kShort)) + // clone to retain ownership of the data + .clone(); + + auto [weight, bias] = params->unpack(); + + non_optional.emplace_back(std::move(params_tensor)); + non_optional.emplace_back(std::move(weight)); + optional.emplace_back(std::move(bias)); + + return std::tie(version, non_optional, optional); +} + +#elif QCONV_SERIALIZATION_VERSION == 3 +using ConvParamsSerializationType = ConvParamsSerializationTypeV3; + +template +ConvParamsSerializationTypeV3 serialize_conv( + const c10::intrusive_ptr>& params) { + std::vector config_vals; + config_vals.push_back(kSpatialDim); + auto stride = params->stride().vec(); + config_vals.insert(config_vals.end(), stride.begin(), stride.end()); + auto padding = params->padding().vec(); + config_vals.insert(config_vals.end(), padding.begin(), padding.end()); + auto dilation = params->dilation().vec(); + config_vals.insert(config_vals.end(), dilation.begin(), dilation.end()); + auto output_padding = params->output_padding().vec(); + config_vals.insert(config_vals.end(), output_padding.begin(), + output_padding.end()); + config_vals.push_back(params->groups()); + config_vals.push_back(params->transpose()); + + auto [weight, bias] = params->unpack(); + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); +} + +#else +#error "Invalid qconv serialization version." 
+#endif + +template +c10::intrusive_ptr> deserialize_conv( + ConvParamsSerializationTypeV3 state) { + auto & [version, config_vals, tensors] = state; + TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ", version); + + TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size()); + auto & weight = tensors[1]; + auto & bias [[maybe_unused]] = tensors[2]; + TORCH_INTERNAL_ASSERT(weight.has_value(), "Weight should always be present in serialized qconv."); + + torch::List stride, padding, output_padding, dilation; + // skip kSpatialDim + int idx = 1; + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + stride.emplace_back(config_vals.at(idx)); + idx++; + } + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + padding.emplace_back(config_vals.at(idx)); + idx++; + } + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + dilation.emplace_back(config_vals.at(idx)); + idx++; + } + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + TORCH_INTERNAL_ASSERT( + idx < static_cast(config_vals.size()), + "Unexpected index = ", + idx, + " for config_vals of size ", + config_vals.size()); + output_padding.emplace_back(config_vals.at(idx)); + idx++; + } + int64_t groups [[maybe_unused]] = config_vals.at(idx); + idx++; + int64_t flags [[maybe_unused]] = config_vals.at(idx); + idx++; + TORCH_INTERNAL_ASSERT(idx == static_cast(config_vals.size()), + "Unexpected length of config_vals, expected ", + idx, + " got ", + config_vals.size()); + + bool transpose [[maybe_unused]] = flags & (1 << 0); + + int64_t other_flags = flags & ~(1 << 0); + TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, "."); + + auto& ctx = at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::X86) { +#if AT_MKLDNN_ENABLED() + bool use_onednn = onednn_utils::should_use_onednn_quant( + weight.value(), transpose, groups, output_padding); + if (use_onednn) { + return PackedConvWeightsOnednn::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif + return PackedConvWeight::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } // x86 +#endif + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return PackedConvWeight::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // USE_FBGEMM +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + TORCH_CHECK( + kSpatialDim == 2, + "prepack/__setstate__: QNNPACK only supports Conv2d " + "now."); + return PackedConvWeightsQnnp::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return PackedConvWeightsOnednn::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // AT_MKLDNN_ENABLED() +TORCH_CHECK( + false, + "Didn't find engine for when deserializing ConvPackedParams: ", + toString(ctx.qEngine())); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h new 
file mode 100644 index 0000000000000000000000000000000000000000..c6113958e4b986cc7ec9f51b0960fedb32f59047 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h @@ -0,0 +1,408 @@ +#pragma once + +#include +#include +#include +#include +#include + +#ifdef USE_FBGEMM +#include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winconsistent-missing-destructor-override") +#include +C10_DIAGNOSTIC_POP() +#include + +// The struct for the packed weight matrix (PackBMatrix) and the corresponding +// column offsets used for the fully connect layer, which are both prepared in +// the prepacking step to save the computations in the inference. Note the +// column offsets include the sum of the B columns as well as the scalar term +// B_zero_point * K, whereas the row offsets created by +// PackAWithQuantRowOffset/PackAWithIm2Col/PackAWithRowOffset are only the sum +// of the A rows. The column offsets are needed for the asymmetric quantization +// (affine quantization) of input matrix. +// Note that in JIT mode we can think of a way to fuse col_offsets with bias. +struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { + PackedLinearWeight( + std::unique_ptr> w, + std::optional bias, + std::vector col_offsets, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme) + : w(std::move(w)), + bias_(std::move(bias)), + col_offsets(std::move(col_offsets)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(std::move(q_scheme)) {} + std::unique_ptr> w; + std::optional bias_; + std::vector col_offsets; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor& apply_out( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output) override; + + at::Tensor& apply_relu_out( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output) override; + + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + + at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) + override; + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + private: + template + at::Tensor& apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output); + + template + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32_impl( + const at::Tensor& input, + double input_scale, + int64_t input_zero_point); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false); +}; + +struct TORCH_API PackedLinearWeightFp16 : public LinearPackedParamsBase { + PackedLinearWeightFp16( + std::unique_ptr w, + std::optional bias) + : w(std::move(w)), bias_(std::move(bias)) {} + + std::unique_ptr w; + std::optional bias_; + + at::Tensor apply( + at::Tensor /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/) override { + 
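+ // Fp16-packed weights are only used by the dynamically quantized linear path, so the statically quantized apply()/apply_relu() entry points are intentionally left unimplemented.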
TORCH_INTERNAL_ASSERT(false); + } + at::Tensor apply_relu( + at::Tensor /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/) override { + TORCH_INTERNAL_ASSERT(false); + } + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) + override; + + at::Tensor& apply_dynamic_out( + const at::Tensor& input, + at::Tensor& output, + bool reduce_range = false) override; + at::Tensor& apply_dynamic_relu_out( + const at::Tensor& input, + at::Tensor& output, + bool reduce_range = false) override; + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + void set_bias(std::optional bias) override; + + private: + template + at::Tensor& apply_dynamic_impl(const at::Tensor& input, at::Tensor& output); +}; + +template +struct TORCH_API PackedConvWeight : public ConvPackedParamsBase { + PackedConvWeight( + std::unique_ptr> w, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose, + std::vector col_offsets, + std::vector kernel, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme) + : w(std::move(w)), + bias(std::move(bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose), + col_offsets(std::move(col_offsets)), + kernel(std::move(kernel)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(q_scheme) {} + + std::unique_ptr> w; + std::optional bias; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + std::vector col_offsets; + std::vector kernel; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + const float* GetBiasData(at::Tensor* bias); + + void GetQuantizationParams( + float act_scale, + float out_scale, + std::vector* output_multiplier_float, + std::vector* act_times_w_scale); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +}; + +// PackWeight: Convert the weight from uint8 to int8. 
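+// For example, a uint8 value of 200 becomes int8 72 (200 - 128), and 0 becomes -128; convert_int8_uint8 below applies the inverse shift.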
+inline void convert_uint8_int8( + int len, + const uint8_t* src_uint8, + int8_t* dst_int8) { + for (const auto i : c10::irange(len)) { + dst_int8[i] = static_cast(static_cast(src_uint8[i]) - 128); + } +} + +// UnpackWeight: Convert the weight from int8 to uint8. +inline void convert_int8_uint8( + int len, + const int8_t* src_int8, + uint8_t* dst_uint8) { + for (const auto i : c10::irange(len)) { + dst_uint8[i] = + static_cast(static_cast(src_int8[i]) + 128); + } +} + +namespace at::native::fbgemm_utils { + +template +fbgemm::conv_param_t MakeFbgemmConvParam( + int N, + int C, + int M, + const std::vector& image_shape, + int groups, + const std::vector& kernels, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::vector& output_padding = std::vector(kSpatialDim, 0), + bool transposed = false); + +// TODO: Remove functions below when ChannelsLast3d is ready. +Tensor MakeStridedQTensorCPU( + const IntArrayRef& sizes, + const IntArrayRef& strides, + const TensorOptions& options, + QuantizerPtr quantizer); + +Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor( + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W, + const TensorOptions& options, + double scale, + int64_t zero_point); + +Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor( + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W, + const TensorOptions& options, + const Tensor& scales, + const Tensor& zero_points); + +Tensor ConvertToChannelsLast3dTensor(const Tensor& src); + +template +Tensor TransposeConvTensorUnpackConversion(const Tensor& src, int groups); + +template +Tensor ConvertConvWeightsToChannelLastTensor( + const at::Tensor& src, + int groups, + bool transpose); +} // at::native::namespace fbgemm_utils + +#endif // USE_FBGEMM + +struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { + PackedEmbeddingBagWeight( + at::Tensor packed_w, + std::vector w_scale, + std::vector w_zp, + int64_t bit_rate, + c10::QScheme q_scheme, + int64_t version) + : packed_w(std::move(packed_w)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + bit_rate_(bit_rate), + q_scheme(q_scheme), + version_(version) { + if (!this->packed_w.is_contiguous()) { + this->packed_w = this->packed_w.contiguous(); + } + } + + at::Tensor packed_w; + std::vector w_scale; + std::vector w_zp; + int64_t bit_rate_; + c10::QScheme q_scheme; + int64_t version_; + + at::Tensor unpack() override; + static c10::intrusive_ptr prepack( + at::Tensor weight); + + int64_t bit_rate() const override { + return bit_rate_; + } + + int64_t version() const override { + return version_; + } + + at::Tensor embeddingbag_byte( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) override; + + at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) override; +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h new file mode 100644 index 0000000000000000000000000000000000000000..b2f24e678d1b0305c7698b3f2cd1dde56268fe08 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h @@ -0,0 +1,11 @@ +#pragma once + +#ifdef USE_PYTORCH_QNNPACK + +namespace at::native { + +void initQNNPACK(); + +} // namespace at::native + +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qconv.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qconv.h new file mode 100644 index 0000000000000000000000000000000000000000..83e1f94fe272508c6081b09da42d8a1c55349f50 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qconv.h @@ -0,0 +1,100 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +class QConvoneDNN final { + public: + + C10_API static at::Tensor run_pointwise( + at::Tensor act, // contains quantized values but not QTensor + double act_scale, + int64_t act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view attr, + torch::List> scalars, + std::optional algorithm); + + C10_API static at::Tensor run_pointwise_tensor( + at::Tensor act, // contains quantized values but not QTensor + at::Tensor act_scale, + at::Tensor act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view attr, + torch::List> scalars, + std::optional algorithm); + + C10_API static at::Tensor run_pointwise_binary( + at::Tensor act, // contains quantized values but not QTensor + double act_scale, + int64_t act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + at::Tensor accum, // contains quantized values but not QTensor + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double accum_scale, + int64_t accum_zero_point, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + + C10_API static at::Tensor run_pointwise_binary_tensor( + at::Tensor act, // contains quantized values but not QTensor + at::Tensor act_scale, + at::Tensor act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + at::Tensor accum, // contains quantized values but not QTensor + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double accum_scale, + int64_t accum_zero_point, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h 
b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h new file mode 100644 index 0000000000000000000000000000000000000000..4f6a480e5595704fb0b66af4a69f1ce013fd5d46 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h @@ -0,0 +1,32 @@ +#pragma once +#include +#include + +namespace at::native { +Tensor& embedding_bag_byte_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const std::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset); + +Tensor& embedding_bag_4bit_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const std::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset); + +Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h new file mode 100644 index 0000000000000000000000000000000000000000..ce7ac6e6ad8ff23d4fdd0632297fb81b090cf2a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h @@ -0,0 +1,12 @@ +#pragma once +#include + +namespace at::native { + +Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight); + +Tensor qembeddingbag_byte_prepack(const Tensor& weight); + +Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qlinear.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qlinear.h new file mode 100644 index 0000000000000000000000000000000000000000..67c440ae60f9f576dbab82f2e3ed91db3940e512 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qlinear.h @@ -0,0 +1,51 @@ +#pragma once +#include +#include + +namespace at::native { + +class QLinearOnednn final { + public: + C10_API static Tensor run_pointwise_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view post_op_name, + c10::List> post_op_args, + std::string_view post_op_algorithm); + +C10_API static Tensor run_pointwise_binary_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional other, // extra input for binary post-op + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double other_scale, + int64_t other_zero_point, + std::string_view binary_post_op, // e.g. "none", "sum", "add" + double binary_alpha, + std::string_view unary_post_op, // e.g. 
"none", "relu" + c10::List> unary_post_op_args, + std::string_view unary_post_op_algorithm); +}; + +C10_API Tensor _weight_int4pack_mm_cpu_tensor( + const Tensor& A, + const Tensor& B, + const Tensor& qGroupSize, + const Tensor& qScaleAndZeros); + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cudnn/utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cudnn/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..e7476bf089db339f0c5a007450127fcbc7267027 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/cudnn/utils.h @@ -0,0 +1,318 @@ +#pragma once +/* +This file contains some of the auxiliary functions used by both Conv.cpp & Linear.cpp (introduced in a later PR) +*/ + +#ifdef USE_CUDA +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() + +#include +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") +#include +C10_DIAGNOSTIC_POP() + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +struct PackedLinearWeightCudnn : public LinearPackedParamsBase { + PackedLinearWeightCudnn( + at::Tensor orig_weight, + std::optional bias, + c10::QScheme q_scheme) + : orig_weight(std::move(orig_weight)), + bias_(std::move(bias)), + q_scheme(std::move(q_scheme)) {} + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) override { + throw std::runtime_error( + "apply_dynamic is not implemented for this packed " + "parameter type"); + } + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) override { + throw std::runtime_error( + "apply_dynamic_relu is not implemented for this packed " + "parameter type"); + } + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + private: + at::Tensor orig_weight; + std::optional bias_; + c10::QScheme q_scheme; + + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + + template + void apply_impl_helper( + const at::Tensor& quantized_output, + const at::Tensor& input, + double output_scale); +}; + +template +struct PackedConvWeightCudnn : public ConvPackedParamsBase { + PackedConvWeightCudnn( + at::Tensor orig_weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose, + c10::QScheme q_scheme, + int64_t output_channels) + : maybe_padded_weight_(std::move(orig_weight)), + bias_(std::move(bias)), + stride_(stride), + padding_(padding), + output_padding_(output_padding), + dilation_(dilation), + groups_(groups), + transpose_(transpose), + q_scheme_(q_scheme), + num_unpadded_output_channels_(output_channels) {} // output channels needs to be stored when we have to pad this dimension + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override { + TORCH_CHECK(false, 
"apply_dynamic is currently not reported"); + } + + at::Tensor apply_dynamic_relu( + const at::Tensor& input, + bool reduce_range) { + TORCH_CHECK(false, "apply_dynamic_relu is currently not reported"); + } + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + const float* GetBiasData(at::Tensor* bias); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return transpose_; + } + + private: + // cudnn v8.4.0 expects conv2d's int8 weight tensor's input and output channels to be a multiple of 4. if it is not + // we need to explicitly pad it to a multiple of 4 ourselves as cudnn does not currently support padding, hence the naming + // convention "maybe"_padded_weight. + // TODO: when and if cudnn enables padding in their operators, we can remove padding on our end and rename this to orig_weight_ + at::Tensor maybe_padded_weight_; + std::optional bias_; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + bool transpose_; + c10::QScheme q_scheme_; + int64_t num_unpadded_output_channels_; + + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + + template + void apply_impl_helper( + const at::Tensor& quantized_output, + const at::Tensor& input, + double output_scale); +}; + +namespace cudnn_utils { + +// TODO: we can remove this function when cuDNN enables pass by value support for +// pointwise multiplication operations. the only reason why we need this right now is +// we use broadcasting scalar multiplication in conv, linear, and add ops, and cuDNN requires +// the scalar to be a scalar tensor with the same number of dimensions (num_dim) as the tensor we're multiplying to +inline at::Tensor getRequantMultiplierTensor(double requant_multiplier, uint8_t num_dim) { + at::SmallVector requantize_multiplier_tensor_size(num_dim, 1); + at::Tensor requantize_multiplier_tensor = at::empty(requantize_multiplier_tensor_size, at::device(at::kCUDA).dtype(at::kFloat)); + requantize_multiplier_tensor.fill_(requant_multiplier); + return requantize_multiplier_tensor; +} + +inline uint8_t getAlignment(const at::Tensor &t) { + // alignment are in bytes + uint8_t alignment = 1; + uintptr_t address = reinterpret_cast(t.data_ptr()); + for (; alignment < 16; alignment *= 2) { + if (address % (alignment * 2)) { + return alignment; + } + } + return alignment; +} + +// For the two getTensorDescriptor functions, there is a is_virtual parameter. This parameter is used to set the cudnn +// tensor as virtual or not. Setting the tensor as virtual is expected to have some performance benefits as the cudnn +// backend cudnn will no longer directly save to the tensor, allowing us to omit this tensor from the variant pack. 
+// See third_party/cudnn_frontend/samples/fusion_sample.cpp for other examples + +inline cudnn_frontend::Tensor getTensorDescriptor(const at::Tensor &t, int64_t id, uint8_t alignment, bool is_virtual = false) { + auto shape = t.sizes(); + auto strides = t.strides(); + if (is_virtual) { + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setVirtual() + .setDataType(at::native::getCudnnDataType(t)) + .build(); + } + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setDataType(at::native::getCudnnDataType(t)) + .build(); +} + +inline cudnn_frontend::Tensor getTensorDescriptor(const c10::IntArrayRef& shape, const c10::IntArrayRef& strides, cudnnDataType_t cudnn_dtype, int64_t id, uint8_t alignment, bool is_virtual = false) { + if (is_virtual) { + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setVirtual() + .setDataType(cudnn_dtype) + .build(); + } + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setDataType(cudnn_dtype) + .build(); +} + +// TODO: there is a table from input dtype to operator dtype, we can derive +// the operator dtype based on input dtype +inline cudnn_frontend::PointWiseDesc_v8 getPointWiseMulDescriptor(cudnnDataType_t dataType) { + return cudnn_frontend::PointWiseDescBuilder() + .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_MUL) + .setMathPrecision(dataType) + .build(); +} + +// TODO: there is a table from input dtype to operator dtype, we can derive +// the operator dtype based on input dtype +inline cudnn_frontend::PointWiseDesc_v8 getPointWiseAddDescriptor(cudnnDataType_t dataType) { + return cudnn_frontend::PointWiseDescBuilder() + .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_ADD) + .setMathPrecision(dataType) + .build(); +} + +// TODO: there is a table from input dtype to operator dtype, we can derive +// the operator dtype based on input dtype +inline cudnn_frontend::PointWiseDesc_v8 getPointWiseReluDescriptor(cudnnDataType_t dataType) { + return cudnn_frontend::PointWiseDescBuilder() + .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(dataType) + .build(); +} + + +inline void filterEngineConfigs( + cudnn_frontend::EngineConfigList &from, + cudnn_frontend::EngineConfigList &to, + bool deterministic, bool allow_tf32, c10::ScalarType scalar_type) +{ + auto filter = [=](cudnnBackendDescriptor_t c) { + if (deterministic) { + if (cudnn_frontend::hasNumericalNote(c)) return true; + } + if (scalar_type == at::kFloat || scalar_type == at::kChar || !allow_tf32) { + if (cudnn_frontend::hasNumericalNote(c)) return true; + if (cudnn_frontend::hasNumericalNote(c)) return true; + } + return false; + }; + cudnn_frontend::filter(from, to, filter); +} + +} // cudnn_utils + +#endif // AT_CUDNN_ENABLED +#endif // USE_CUDA diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/library.h b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/library.h new file mode 100644 index 0000000000000000000000000000000000000000..5a6137d6b924afc8f72f760563b2083c920f2952 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/quantized/library.h @@ -0,0 +1,8 @@ 
+#pragma once + +#include + +TORCH_API int register_linear_params(); +int register_embedding_params(); + +template TORCH_API int register_conv_params(); diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/attention.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/attention.h new file mode 100644 index 0000000000000000000000000000000000000000..116927a2b32845674048febaef3b341a921c20ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/attention.h @@ -0,0 +1,70 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at::native { + +using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value, + const std::optional& attn_mask_, double dropout_p, bool is_causal, std::optional scale, bool enable_gqa); + +DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub) + +TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b); +TORCH_API Tensor masked_softmax( + Tensor& attn_scores, + std::optional attn_mask, + const Tensor& query, + std::optional mask_type = {}); + +using transform_bias_rescale_qkv_fn = void(*)( + at::ScalarType type, + void* _q_k_v, + const void* _qkv, + const void* _qkv_bias, + int64_t B, + int64_t T, + int64_t D, + int64_t num_head); + +DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub) + +TORCH_API Tensor transform0213_gemm_nt_bias( + const Tensor& a, + const Tensor& b, + const Tensor& c, + const Tensor& query); + +TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b); + +TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape); + +TORCH_API Tensor qkv_projection( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const int64_t embed_dim, + const Tensor& qkv_weight); + +using flash_attention_fn = void (*)( + const Tensor& output, const Tensor& logsumexp, + const Tensor& query, const Tensor& key, const Tensor& value, + double dropout_p, bool is_causal, + std::optional attn_mask, + std::optional scale); + +using flash_attention_backward_fn = void (*)( + const Tensor& grad_q, const Tensor& grad_k, + const Tensor& grad_v, const Tensor& grad_out, + const Tensor& query, const Tensor& key, + const Tensor& value, const Tensor& out, const Tensor& logsumexp, + double dropout_p, bool is_causal, + std::optional attn_mask, + std::optional scale); + +DECLARE_DISPATCH(flash_attention_fn, flash_attention_kernel) +DECLARE_DISPATCH(flash_attention_backward_fn, flash_attention_backward_kernel) + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/flash_api.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/flash_api.h new file mode 100644 index 0000000000000000000000000000000000000000..57e7eea39dfb87e86760ef6cbd19d421eaaa0ed4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/flash_api.h @@ -0,0 +1,96 @@ +#pragma once +#include + +#include +#include +#include + +namespace FLASH_NAMESPACE { + +TORCH_API +std::tuple +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float 
softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + + +std::tuple +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +} // namespace FLASH_NAMESPACE diff --git 
a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/static_switch.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..251779a4fde40a4d5f3825f2e302ca4ecda9ad94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/static_switch.h @@ -0,0 +1,107 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI + #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define LOCAL_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) 
\ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..e8a751cd0f7e2b617a2c66bb513c9612c9b9720d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h @@ -0,0 +1,210 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (int _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_B0_T0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) 
\ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_B0_T0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + --count; + if (!count) + return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_B0_T0( \ + "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ + for (int _start = 0; _start < frag.size(); _start += 8) { \ + PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + PRINT_B0_T0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + 
PRINT_B0_T0( \ + "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..ea1d74f25127b998fface16a21bbe5b00789c50b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,631 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. 
+ + File copied from + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + ///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + ///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + 
using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + using SourceAccessType = Array< + typename OutputTileSourceIterator::Element, + OutputTileSourceIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + 
compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } 
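  /// Maps a flat element index `i` within the output fragment to the row offset
  /// (in rows, relative to this thread's starting row) of the access containing it.
  /// The loops walk the OutputTileIterator ThreadMap's cluster/group/row/column
  /// iteration space, recompute each access's starting fragment index as
  /// kElementsPerAccess * (frag_row_idx * Iterations::kColumn + column), and return
  /// the row offset of the first access whose element range covers `i`; -1 means
  /// `i` is out of range. Both apply_output_operator_* helpers above add the
  /// returned offset to the destination iterator's thread_start_row() before
  /// invoking ApplyEpilogueOp, so row-aware output functors (such as the attention
  /// rescaling epilogue, which indexes s_prime / m_prime by row) receive an
  /// absolute row id.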
+ + constexpr int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000000000000000000000000000000000..1d2d4008af0717e8391cfa0b745150bb16d443c9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,238 @@ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, and uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha * m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to store tensors + typename ElementSource_, //< Data type for source (usually matches + //`ElementOutput`) + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const { + assert(!isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row]; + ElementCompute alpha = isLast ? (1 / denom) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 
1 : s_prime_[row]; + ElementCompute alpha = isLast ? (1 / denom) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template < + typename EO, + typename ES, + int Count, + typename EA, + typename EC, + bool F, + bool L, + typename FAB, + FloatRoundStyle R> +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e6365bf7bdaff2cf1186dd3ffd322d6e0c3426 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,175 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = expf(input[i]); + } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()( + Array const& input) const { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = + reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = h2exp(input_ptr[i]); + } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template < + typename ElementOutput_, // output + typename ElementLSE_, // accumulator from LSE + typename ElementAccumulator_, // accumulator from matmul + typename ElementCompute_, // intermediate compute (and exp calculation) + int ElementsPerAccess> +class ApplyLogSumExp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + + public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const { + FragmentCompute frag_AB = NumericArrayConverter< + ElementCompute, + ElementAccumulator, + kElementsPerAccess>()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()( + bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter< + ElementOutput, + ElementCompute, + kElementsPerAccess>()(frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + 
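The ApplyLogSumExp functor above computes exp(input - lse) elementwise: when the input fragment holds a row of raw attention scores s and the bias fragment holds that row's logsumexp lse = log(sum_j exp(s_j)), the result is exactly softmax(s)_j = exp(s_j - lse), so the normalized attention probabilities are reproduced without redoing the row reduction. A minimal standalone sketch of that arithmetic, in plain C++ with made-up score values and none of the CUTLASS fragment or converter machinery:

// Sketch only: illustrates the exp(s - lse) identity used by ApplyLogSumExp.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One hypothetical row of raw attention scores s = q . k^T.
  std::vector<float> scores = {1.0f, 2.0f, 0.5f, -1.0f};

  // Numerically stable logsumexp: lse = max + log(sum_j exp(s_j - max)).
  float mx = *std::max_element(scores.begin(), scores.end());
  float sum = 0.0f;
  for (float s : scores) sum += std::exp(s - mx);
  float lse = mx + std::log(sum);

  // exp(s_j - lse) equals softmax(s)_j, which is what the functor applies
  // elementwise to the matmul accumulator using the stored LSE as "bias".
  float total = 0.0f;
  for (float s : scores) {
    float p = std::exp(s - lse);
    total += p;
    std::printf("%.6f ", p);
  }
  std::printf("\nsum of probabilities = %.6f\n", total);  // ~1.0
  return 0;
}

The max-shifted form of the logsumexp keeps the exponentials from overflowing; the functor itself assumes the LSE it receives was computed that way upstream.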
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..9be54f025d343bcce28b21ac4957a7fcfbacc815 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + +#include +#include +template +struct MakeCustomMma; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + cutlass::arch::CacheOperation::Kind CacheOpA, + typename IteratorB, + typename SmemIteratorB, + cutlass::arch::CacheOperation::Kind CacheOpB, + typename ElementC, + typename LayoutC, + typename Policy, + int Stages, + cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + Stages, + SharedMemoryClear>, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? Stages + : cutlass::const_min( + Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + kStages, + SharedMemoryClear, + kMaxK>; +}; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + typename IteratorB, + typename SmemIteratorB, + typename ElementC, + typename LayoutC, + typename Policy, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>; +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..b833b27efd05a38034bdf35418a28dbcaa6a5998 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef{buffer.data(), Layout()}; + } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape< + Shape::kM + Policy::SmemPaddingA::kRow, + Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + using SharedStorageA = OperandSharedStorage< + typename Operator::ElementA, + ShapeA, + typename Operator::LayoutA>; + using SharedStorageB = OperandSharedStorage< + typename Operator::ElementB, + ShapeB, + typename Operator::LayoutB>; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..925f9e153f484d52c0f0fd7c49030335e1c6bfb3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h @@ -0,0 +1,768 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved.
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Upper bound on the K dimension + int kMaxK = cutlass::platform::numeric_limits::max(), + /// Used for partial specialization + typename Enable = bool> +class CustomMmaMultistage : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection.
+ struct Detail { + static_assert( + Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load one group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load one group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireMat ? Stages : Stages - 1; + + private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `true` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of 
warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + void set_zero_outside_bounds(bool value) { + zero_outside_bounds_ = value; + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue( + iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index( + group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index( + group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + 
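  // Note on copy_tiles_and_advance() above: every global->shared access is
  // issued with cp.async, predicated on iterator_A.valid() / iterator_B.valid().
  // When `zero_outside_bounds_` is set (or the kernel is built with
  // SharedMemoryClearOption::kZfill), cp_async_zfill is used, which writes
  // zeros into shared memory for predicated-off accesses, so out-of-bounds
  // elements are guaranteed to read back as zero. The plain cp_async path
  // instead skips invalid accesses and leaves the destination untouched,
  // which is cheaper but only safe when stale shared-memory contents can
  // never reach the accumulator.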
template + CUTLASS_DEVICE static void _prologue( + IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform( + warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. 
+ if (!kSmemContainsEntireMat || + gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load( + warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform( + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for this stage + if (!kSmemContainsEntireMat && + warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && + smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..de6d2e6dee917009ff50c5fc0cf738c0a4eb7c70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h @@ -0,0 +1,402 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // statically assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined( + typename Base::SharedStorageA& 
shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + void set_zero_outside_bounds(bool value) { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = + TransformA(), ///< transformation applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + 
iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h 
b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..aec5abf085cf4fa055f94e892aa2011f16022a1d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h @@ -0,0 +1,167 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +/*! \file + \brief Cutlass provides helper template functions to figure out the right + data structures to instantiate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our use case. + + This is really only for the FastF32 case - aka using TensorCores with fp32. +*/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + AccumulatorsInRowMajor, + SharedMemoryClear>; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template < + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename 
WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + int kStages, + typename Operator> +struct FindDefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 3, + Operator>; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, + ElementAccumulator, + LayoutC, + typename MmaCore_::MmaPolicy, + kStages>; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..dc0774cfafa63302d10eb006dc56fa35e624d7b9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,354 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. 
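As a concrete illustration (assuming the Sm80 case below with an m16n8
tensor-op accumulator tile, i.e. InstructionShape::kM = 16 and kN = 8, and a
warp tile equal to a single MMA tile): kElementsPerAccess = 2, kRowsPerTile = 8
and kAccumulatorRows = 2, so lane 5 of the warp has quad = 1 and
lane_in_quad = 1, get_lane_offset() returns (row 1, column 2), and
iterateRows() visits that lane's four accumulator elements at (1,2), (1,3),
(9,2) and (9,3). Other instruction shapes and architectures map differently,
which is why each one gets its own iterator specialization below.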
+*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) 
>> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert( + cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN 
+ + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert( + cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template < + typename S1, + typename S2, + typename S3, + typename accum_t, + int kWarpSize> +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..a5433d074ed86463d66e9d078c4b55aff24a0d40 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h @@ -0,0 +1,1948 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template < + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_> +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass:: + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
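+/// Unlike the regular MmaBase, operand A is never staged from global memory here:
+/// it is read directly out of a shared-memory tile (the accumulator of a previous
+/// GEMM), so the SharedStorage below only needs a buffer for operand B.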
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum K dimension - also the dimension of the shared-memory + // holding `OperandA` + int kMaxK_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Layout in shared-memory of operand A + typename SmemLayoutA, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + static constexpr int kMaxK = kMaxK_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + TensorRefB& b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(b_tile, lane_idx) {} +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. 
this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // Since arrays of zero-sized objects are not allowed, using size as 1. + // The compiler will most likely wipe it out anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
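+/// Two-stage (double-buffered) variant: operand A, and optionally a per-element
+/// scale for A, is read from shared memory, while operand B is streamed from
+/// global memory through a shared-memory staging tile. Presumably selected for
+/// pre-Sm80 targets, where cp.async is unavailable (see the multistage variant below).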
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Max GEMM problem size in K dimension + int MaxK, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout>; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + using WarpIteratorA = WarpIteratorA_; + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // statically assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
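+  /// (the ScaleOperandA == false specialization of FragmentElementwiseScaler
+  /// simply returns the fragment, so the unscaled path adds no extra work.)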
+ using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, // Operand A in shared memory + typename Base::TensorRefA a_scale, // Operand A_scale in shared memory + typename Base::TensorRefB + b_staging, // staging memory for loading tiles of B + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + warp_tile_iterator_A_scale_(a_scale, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, ///< Operand A in shared memory + typename Base::TensorRefB b_staging, ///< staging memory for loading B + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx) ///< ID of each thread within a warp + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); 
+ } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async transfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
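+        // hasNext is cleared on the last k-group of the final mainloop iteration,
+        // so the fragment prefetch below is skipped once there is nothing left to read.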
+ bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + if (gemm_k_iterations > 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + } + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma( + accum, + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
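+/// Multistage variant (requires Sm80+ for cp.async): operand B is copied
+/// global->shared across Stages_ pipeline stages and overlapped with the
+/// warp-level MMAs, while operand A is again read from the shared-memory
+/// accumulator tile.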
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout>; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireB ? 
Base::kStages : Base::kStages - 1; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefA a_scale, + typename Base::TensorRefB b_tile, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + warp_tile_iterator_A1_scale_(a_scale, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefB b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % 
(Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue( + iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index( + group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue( + IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
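+      // cp_async_fence() commits the copies issued above into one cp.async group;
+      // the mainloop later waits on these groups via cp_async_wait<>().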
+ cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
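+    // (this staging path is only taken when MathOperator is OpMultiplyAddFastF32 or
+    // OpMultiplyAddComplexFastF32; for all other operators tmp_accum stays unused.)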
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma1( + tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
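+          // i.e. allow at most kStages-2 cp.async groups to remain in flight, which
+          // guarantees the stage about to be read from shared memory has landed.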
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template < + typename Mma_, + int kMaxK, + typename WarpIteratorA_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA = false> +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + // Max MMA problem size K + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = 
MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = WarpIteratorA_; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + kMaxK, + IteratorB, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform::conditional< + kIsTransposedA, + typename WarpIteratorTranspose::Iterator, + WarpIteratorA_>::type; + + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, 
+ kScaleOperandA, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename IteratorC, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? 
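+  // Iterator over the per-row logsumexp (LSE) values consumed by EpilogueOpApplyLSE,
+  // which rescales the accumulator (roughly exp(acc - lse)) before it is written to
+  // shared memory for the backward pass.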
+ using IteratorAccumulatorLSE = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + IteratorAccumulatorLSE, + EpilogueOpApplyLSE>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue( + minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using SmemAccumulatorLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + SmemAccumulatorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + 
// Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast( + ref_.data() + ref_.offset({r, c})) = to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + 
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template < + typename Operator, + typename OperatorPolicy, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to 
registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..21f1bd0eed808b868918d3983e738e547e7e0470 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h @@ -0,0 +1,209 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + TORCH_CHECK( \ + false, \ + "Your device is too old. 
We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) +// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? 
fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE T warp_uniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..699f62311fba3ae841ad99dd5bc8b2e114f7f1c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
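The helpers at the end of gemm_kernel_utils.h above are small but load-bearing: ceil_div and align_up are plain integer rounding helpers, and warp_uniform forces a value to be identical across all 32 lanes of a warp by broadcasting lane 0's copy through __shfl_sync, which lets the compiler treat addressing derived from it as warp-uniform. A minimal usage sketch (assuming the header above is already included; warp_uniform_demo is a made-up kernel name):

    // ceil_div(10, 4) == 3 and align_up(10, 4) == 12: ordinary integer rounding.
    __global__ void warp_uniform_demo(int* out, int n) {
      int chunks = gemm_kernel_utils::ceil_div(n, 32);
      // Broadcast lane 0's value so every lane, and the compiler, agrees on it.
      chunks = gemm_kernel_utils::warp_uniform(chunks);
      out[threadIdx.x] = chunks;
    }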
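Similarly, the accumApplyLSEToSmem routines earlier in this patch normalize raw attention logits by subtracting a precomputed log-sum-exp (LSE) before exponentiating, i.e. softmax(x)_j = exp(x_j - lse) with lse = log(sum_k exp(x_k)). A scalar sketch of that identity (apply_lse_row is a hypothetical host-side helper, not part of the kernels):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Mirrors accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]) in the kernels.
    std::vector<float> apply_lse_row(std::vector<float> const& logits, float lse) {
      std::vector<float> probs(logits.size());
      for (std::size_t j = 0; j < logits.size(); ++j) {
        probs[j] = std::exp(logits[j] - lse);
      }
      return probs;
    }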
+ * + **************************************************************************************************/ +/*! \file + \brief Instantiates the right WarpIterator to read from shared memory + The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading + data dumped with `B2bGemm::accumToSmem`. +*/ + +#pragma once + +#include +#include +#include + +#include + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, kInstrK>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::MatrixShape>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. 
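Because DefaultWarpIteratorAFromSharedMemory only selects a type, using it amounts to naming its nested WarpIterator alias for the shapes and policy of the surrounding MMA. A hypothetical sketch (RegularWarpIterator_ and Policy_ stand in for whatever the enclosing MMA already defines; the shapes shown pick the Ampere 16-bit specialization):

    template <typename RegularWarpIterator_, typename Policy_>
    struct PickWarpIteratorA {
      using WarpIterator = typename cutlass::gemm::threadblock::
          DefaultWarpIteratorAFromSharedMemory<
              cutlass::gemm::GemmShape<32, 32, 32>,  // WarpShape
              cutlass::gemm::GemmShape<16, 8, 16>,   // InstructionShape (assumed)
              RegularWarpIterator_,
              Policy_>::WarpIterator;
    };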
+ using WarpIterator = RegularWarpIterator; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..7e87205f0f1b439d8d671523bf1aa70d1900f6e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,752 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue iterator that supports prefetching + + Mostly copied from +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false> +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( + ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert( + ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert( + ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert( + ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
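The prefetch() member defined a little further below walks the same cluster/group/row footprint as a load or store, but only issues prefetch.global.L1 PTX so the addressed global-memory lines are pulled toward L1 ahead of the real accesses. The underlying idiom, shown in isolation (a standalone sketch, not part of the patch):

    #include <cstdint>

    __device__ inline void prefetch_global_l1(void const* gmem_ptr) {
      // A 64-bit integer operand is needed to satisfy the 'l' register constraint;
      // the iterator's own code notes that a 32-bit type breaks the inline asm.
      uint64_t addr = reinterpret_cast<uint64_t>(gmem_ptr);
      asm volatile("prefetch.global.L1 [%0];" ::"l"(addr));
    }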
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert( + sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch( + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < + extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * 
ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + 
CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + 
CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) + row_add_P = 0; + if (output_Q > convolution_Q - 2) + row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + 
return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { + mask_ = mask; + } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch< + typename IT::ThreadMap, + typename IT::Element>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..f2655a096643c73bb6b804056c4885708456f24f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..db10d4e4a829078b37a1c57ca94c615e32c38dc5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,2115 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather = false> +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. 
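The residual-first scheme described in the \brief above is easiest to see with concrete numbers: with an extent of 100 elements and a 32-element tile, the iterator visits one partial tile of 100 % 32 = 4 elements first, then three complete tiles, so predicates only ever need to be computed twice (once for the residual tile, once shared by all full tiles). A tiny illustration of that arithmetic (not CUTLASS code):

    #include <cstdio>

    int main() {
      int extent = 100;               // extent along the advancing dimension
      int tile = 32;                  // tile size along that dimension
      int residual = extent % tile;   // 4: size of the first, possibly partial tile
      int full_tiles = extent / tile; // 3: all remaining tiles are complete
      // When extent divides evenly, residual is 0 and the first tile is simply full.
      std::printf("residual=%d full=%d\n", residual, full_tiles);
      return 0;
    }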
+/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base( + layout.stride(0), + MakePredicatedTileAccessIteratorDesc< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap>()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * 
LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / + 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * sizeof_bits::value / 8; + + return reinterpret_cast( + pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. 
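This specialization, and the row-major one after it, are thin adapters over the pitch-linear iterator: they only permute coordinates so that the pitch-linear contiguous dimension lines up with the dimension that is contiguous in memory. A small sketch of the mapping their constructors perform (assuming the usual CUTLASS coordinate headers are available):

    #include <cutlass/matrix_coord.h>
    #include <cutlass/layout/pitch_linear.h>

    // Column-major: rows are contiguous in memory, columns are strided.
    inline cutlass::layout::PitchLinearCoord column_major_to_pitch_linear(
        cutlass::MatrixCoord extent) {
      return cutlass::layout::PitchLinearCoord(extent.row(), extent.column());
    }

    // Row-major: columns are contiguous in memory, rows are strided.
    inline cutlass::layout::PitchLinearCoord row_major_to_pitch_linear(
        cutlass::MatrixCoord extent) {
      return cutlass::layout::PitchLinearCoord(extent.column(), extent.row());
    }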
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : 
PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} 
+ + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. 
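+/// A sketch of the addressing this specialization implements (hypothetical
+/// values, for illustration only): for an AffineRankN<2> layout with strides
+/// (s0, s1), the element at logical coordinate (c, s) lives at byte offset
+/// (c * s0 + s * s1) * sizeof_bits<Element>::value / 8. Rather than wrapping a
+/// pitch-linear iterator, this specialization walks a raw byte pointer using
+/// the precomputed increments inc_contiguous_, inc_strided_ and inc_advance_;
+/// e.g. with s1 = 1024, ThreadMap::Delta::kStrided = 8 and 16-bit elements,
+/// inc_strided_ = 1024 * 8 * 2 = 16384 bytes.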
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + layout::PitchLinear, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. 
+ LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent 
of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. 
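+/// As in the column-major interleaved case above, this wrapper only rescales
+/// coordinates (hypothetical sizes, for illustration only): an extent of
+/// (rows, columns) becomes a pitch-linear extent of
+/// (columns * kInterleavedK, rows / kInterleavedK), so a 128x64 tensor with
+/// kInterleavedK = 4 is presented to the underlying iterator as a 256x32
+/// pitch-linear view.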
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of 
tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..7ccced32b3cc04ef4ae4b4d1c4011ac830a46f67 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,2120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. 
+/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Access out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false> +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
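+/// Fragment sizing sketch (hypothetical thread map, for illustration only):
+/// each thread owns a Fragment of
+/// ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess elements and
+/// fills it with kAccessesPerVector predicated vector loads per iteration. For
+/// a thread map with Iterations of 2 (contiguous) x 4 (strided),
+/// kElementsPerAccess = 8 and AccessSize = 8, that is a 64-element fragment
+/// loaded by 8 guarded global_load calls.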
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< 
ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s 
= 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + 
params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
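+ // Illustrative annotation (not part of the upstream header): the constructor
+ // above folds the interleaving into the pitch-linear extent. For example,
+ // with kInterleavedK = 32 and an extent of (row = 128, column = 64), the
+ // underlying iterator sees a pitch-linear extent of
+ // (128 * 32, 64 / 32) = (4096, 2).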
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
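+ // Illustrative annotation (not part of the upstream header): this row-major
+ // interleaved case mirrors the column-major one above with the coordinates
+ // swapped. For kInterleavedK = 32 and an extent of (row = 64, column = 128),
+ // the underlying pitch-linear extent is (128 * 32, 64 / 32) = (4096, 2).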
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..1242eee7af4e8e564f4f045f91c1ccbdca712888 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + typename InstructionShape, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp:: + WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp:: + WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..baa0ef23493cba177ce8829177b288bddf406cbb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. 
+*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + typename InstructionShape_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + static_assert( + kOperand == Operand::kA, + "No support for OperandB at the moment"); + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16"); + static_assert( + InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16, + "Only supports 16x8x8 / 16x8x16"); + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? 
InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + // Number of 32bits tiles to load per `ldmatrix` + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8"); + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + // See also: + // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 + // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4) + // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4) + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kTilesPerInstruction == 4, + "can't use ldmatrix.x4"); + int access_m_idx = ldsm_vec_num % kTilesPerInstruction; + int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner; + int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner); + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } else { + // XXX: This is not tested or used + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
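+ // Illustrative annotation (not part of the upstream header): a sketch of how
+ // a warp-level mainloop might drive this iterator; `warp_iter`, `frag`, and
+ // the MMA call are assumed caller-side names.
+ //
+ //   Fragment frag;
+ //   for (int k = 0; k < kIterations; ++k) {
+ //     warp_iter.load(frag); // ldmatrix-based load of the k-th instruction slice
+ //     // ... feed `frag` to the warp-level MMA ...
+ //     ++warp_iter;          // calls advance() to the next 32x32 tile after kIterations steps
+ //   }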
+ CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < + (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4; + ++access_m_idx) { + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord( + access_m_idx * 16, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[access_m_idx], ref_.data() + ref_.offset(offset)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..95419b41c84b835a377cfd69236d97f7658aa718 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -0,0 +1,2614 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +using namespace gemm_kernel_utils; + +namespace PyTorchMemEffAttention { +namespace { + +template +struct GmemTile { + /* + Helper functions to efficient store/load RF to gmem + + GEMM accumulators have a particular format on A100, and + it takes some compute/shared-memory to rearrange them to + a RowMajor or ColumnMajor format in global memory through + an Epilogue. The same complexity goes for loading into RF. + + This class loads/stores RF as they are, and can be used for + efficient accumulation across gemms for instance: + + ``` + GmemTile tile; + for (int i = 0; i < N; ++i) { + // ... + + Fragment accum; + if (i == 0) { + accum.clear(); + } else { + tile.load(accum); + } + mma(accum, ...); + if (i < N-1) { + // Store for next GEMM + tile.store(accum); + } else { + // Store in tensor (eg RowMajor) + epilogue(accum); + } + + // ... 
+ } + ``` + */ + + // 128bits per thread + using AccessType = cutlass::Array; + static constexpr int32_t kBytes = sizeof(AccessType); + static constexpr int32_t kStride = kNumThreads * AccessType::kElements; + static constexpr int32_t kNumIters = + FragmentType::kElements / AccessType::kElements; + static constexpr int32_t kElementsStored = + kNumThreads * FragmentType::kElements; + static_assert( + FragmentType::kElements % AccessType::kElements == 0, + "fragment not aligned on 128 bits"); + + float* ptr; + + CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + cutlass::arch::global_load( + sub_fragment, gmem_ptr, true); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + fragment[i * AccessType::kElements + j] = sub_fragment[j]; + } + } + } + + CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + sub_fragment[j] = fragment[i * AccessType::kElements + j]; + } + cutlass::arch::global_store( + sub_fragment, gmem_ptr, true); + } + } + + CUTLASS_DEVICE void storeAtomicAdd( + FragmentType const& fragment, + int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + float val = fragment[i * AccessType::kElements + j]; + float* ptr = gmem_ptr + j; + atomicAdd(ptr, val); + } + } + } +}; + +struct AtomicLock { + CUTLASS_DEVICE static void acquire( + int32_t* lock, + int set_val, + int thread_id) { + if (thread_id == 0) { + while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + __nanosleep(40); +#endif + } + } + __syncthreads(); + } + CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) { + if (thread_id == 0) { + int status = 0; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.gpu.b32 [%0], %1;\n" + : + : "l"(lock), "r"(status)); +#else + asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#endif + } + } +}; + +template +constexpr int getWarpsPerSmBw() { + bool is_half = !cutlass::platform::is_same::value; + if (Arch::kMinComputeCapability >= 80) { + return is_half ? 
12 : 8; + } + return 8; +} +} // namespace + +template < + // which arch we target (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // input/output type + typename scalar_t_, + // run optimized kernel because memory accesses will be aligned + bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout_, + // when doing a GEMM, preload the next one (uses more shmem) + bool kPreload_, + // block dimensions + int kBlockSizeI_, + int kBlockSizeJ_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // assumes that `cu_seqlen` is None, and + // (1) `num_queries % kBlockSizeI == 0` + // (2) `num_keys % kBlockSizeJ == 0` + bool kKeysQueriesAlignedToBlockSize_ = false> +struct AttentionBackwardKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + using scalar_t = scalar_t_; + using output_t = scalar_t; + using output_accum_t = float; + using lse_scalar_t = float; + using accum_t = float; + using ArchTag = ArchTag_; + static constexpr bool kIsAligned = kIsAligned_; + static constexpr bool kApplyDropout = kApplyDropout_; + static constexpr bool kPreload = kPreload_; + static constexpr int kBlockSizeI = kBlockSizeI_; + static constexpr int kBlockSizeJ = kBlockSizeJ_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kKeysQueriesAlignedToBlockSize = + kKeysQueriesAlignedToBlockSize_; + + static constexpr int64_t kWarpSize = 32; + + // If this is true, we store and accumulate dK/dV in RF + // rather than going back to gmem everytime + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static_assert( + !kPreload || + (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), + "preload MMA not supported"); + static constexpr bool kPrologueQK = kPreload; + static constexpr bool kPrologueGV = kPreload; + static constexpr bool kPrologueDOV = kPreload; + static constexpr bool kPrologueGQ = kPreload; + static constexpr bool kPrologueGK = kPreload; + + static constexpr int64_t kNumWarpsPerBlock = + (kBlockSizeI * kBlockSizeJ) / (32 * 32); + + // Compute delta for the f16 kernels + // TODO: Figure out why it's slower on the f32 kernels + // (something due to RF pressure?) + // TODO: Remove condition on `kOutputInRF` - this is needed to work + // around a compiler bug on V100, not exactly sure why but I spent + // too much time on this already. 
Reproducible with + // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance + static constexpr bool kKernelComputesDelta = + kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); + + // Launch bounds + static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int64_t kMinBlocksPerSm = + getWarpsPerSmBw() / kNumWarpsPerBlock; + + using GemmType = DefaultGemmType; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + typename GemmType::OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr auto kOptimalAlignement = cutlass::platform::max( + DefaultConfig::kAlignmentA, + DefaultConfig::kAlignmentB); + static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; + + struct MatmulQK { + /* + attn_T = k_j @ q_i.transpose(-2, -1) # matmul + attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, + -1)).exp() # epilogue + + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + accum_t, // ElementC + cutlass::layout::RowMajor, // LayoutC + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + DefaultConfig::kStages, + typename GemmType::Operator, + false, // AccumulatorsInRowMajor = false, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using MmaCore = typename DefaultMma::MmaCore; + using Mma = + typename MakeCustomMma::Mma; + + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = TileSmemLoader< + scalar_t, + // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded + // row-major but needs to have transposed shape so we get the same + // elements. + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradV { + /* + grad_v[j_start:j_end] += attn_T @ do_i # matmul + + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) + (we might need to iterate multiple times on K) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. + // if no dropout: + // for computing dVj += Pij.T @ dOi + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Operator:: + InstructionShape, // InstructionShape + typename DefaultGemm::Mma::Operator:: + IteratorA, // RegularWarpIterator + typename DefaultGemm::Mma::Policy // Policy + >::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + kApplyDropout>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + struct MatmulDOIVJ { + /* + doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul + tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + + using ElementC = output_t; + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + ElementC, // ElementC + cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. 
set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial + typename GemmType::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using Mma = typename MakeCustomMma::Mma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccum, + kWarpSize>::Iterator; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename DefaultGemm::Mma::Operator::IteratorC, + typename DefaultGemm::Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradQ { + // grad_q <- tmp @ k_j + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + struct MatmulGradK { + // grad_k <- tmp.transpose(-2, -1) @ q_i + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmemN = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using DefaultMmaFromSmemT = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK + WarpIteratorA, + false, // kScaleOperandA + kPreload>; // kTransposeA + using DefaultMmaFromSmem = typename cutlass::platform::conditional< + DefaultMmaFromSmemT::kIsTransposedA, + DefaultMmaFromSmemT, + DefaultMmaFromSmemN>::type; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + // NOTE: nvcc 12.4 has correctness errors with this on M60 (sm52) + // when there is an attention bias. Let's just disable it for now. 
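+  //
+  // For orientation, the split-keys scheme (when enabled) distributes key
+  // tiles across `gridDim.x`. A rough host-side sketch of the scheduling
+  // that `attention_kernel` below implements per block -- illustrative only
+  // and never compiled; `process_block` stands in for `processBlockIJ`:
+#if 0
+  for (int split = 0; split < num_splits_key; ++split) { // == blockIdx.x
+    for (int64_t key_start = split * kBlockSizeJ; key_start < num_keys;
+         key_start += num_splits_key * kBlockSizeJ) {
+      // partial dQ tiles from different splits are merged through
+      // `workspace_gq`, serialized with `AtomicLock`
+      process_block(key_start);
+    }
+  }
+#endif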
+ static constexpr auto kMinSm = ArchTag::kMinComputeCapability; + static constexpr bool kEnableSplitKeys = kMinSm >= 70; + + static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = !kOutputInRF && + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = !kOutputInRF && + !cutlass::platform::is_same::value; + + struct GradQTempStorage { + int32_t lock; + int32_t counter; + int32_t pad[2]; // pad to 128bits + output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored]; + }; + + struct Params { + // Input tensors + const scalar_t* query_ptr = nullptr; // [Mq, nH, K] + const scalar_t* key_ptr = nullptr; // [Mk, nH, K] + const scalar_t* value_ptr = nullptr; // [Mk, nH, Kv] + const scalar_t* bias_ptr = nullptr; + const lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq] + const scalar_t* output_ptr = nullptr; // [Mq, nH, Kv] + const scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv] + accum_t* delta_ptr = nullptr; // [nH, Mq] + const int32_t* cu_seqlens_q_ptr = nullptr; + const int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr = nullptr; // [Mq, nH, K] + output_t* grad_key_ptr = nullptr; // [Mk, nH, K] + output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv] + output_t* grad_bias_ptr = nullptr; + + // Accumulators + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gv = + nullptr; // (will be calculated by the kernel) + GradQTempStorage* workspace_gq = + nullptr; // (will be calculated by the kernel) + + // Sliding window. ignored if == 0 + int32_t window_size = 0; + + // Scale + accum_t scale = 1.0f; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + uint8_t custom_mask_type = NoCustomMask; + + int64_t q_strideM = -1; + int64_t k_strideM = -1; + int64_t v_strideM = -1; + int64_t bias_strideM = 0; + int64_t gO_strideM = -1; + int64_t gB_strideM = -1; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + + at::PhiloxCudaState rng_engine_inputs = {0, 0}; + + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int64_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int64_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int64_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int64_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH = -1; + int32_t q_strideH = -1; + int32_t k_strideH = -1; + int32_t v_strideH = -1; + int64_t bias_strideH = 0; + int64_t o_strideB = -1; + int64_t q_strideB = -1; + int64_t k_strideB = -1; + int64_t v_strideB = -1; + int64_t bias_strideB = 0; + int64_t lse_strideB = -1; + int64_t lse_strideH = -1; + int64_t delta_strideB = -1; + int64_t delta_strideH = -1; + int32_t num_batches = -1; + int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t 
gK_strideH = 0; + int64_t gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_HOST_DEVICE int16_t num_splits_key_device() const { +#ifdef __CUDA_ARCH__ + return kEnableSplitKeys ? gridDim.x : 1; +#else + return num_splits_key; // for host-side tests +#endif + } + CUTLASS_HOST_DEVICE int16_t split_key_device() const { +#ifdef __CUDA_ARCH__ + return kEnableSplitKeys ? blockIdx.x : 0; +#else + return 0; // for host-side tests +#endif + } + + CUTLASS_DEVICE bool advance_to_block() { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = + (GradQTempStorage*)(workspace_gv + workspace_elements_gv()); + if (kEnableSplitKeys) { + workspace_gv += workspace_elements_gv() * split_key_device() / + num_splits_key_device(); + workspace += workspace_elements_gk() * split_key_device() / + num_splits_key_device(); + } + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = + batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + int32_t q_start = cu_seqlens_q_ptr[0]; + int32_t k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + assert(q_next_start - q_start <= num_queries); + assert(k_next_start - k_start <= num_keys); + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + // Jump manually + batch_id = 0; + + query_ptr += q_start * q_strideM; + key_ptr += k_start * k_strideM; + value_ptr += k_start * v_strideM; + assert(bias_ptr == nullptr); + assert(grad_bias_ptr == nullptr); + output_ptr += q_start * o_strideM(); + grad_output_ptr += q_start * gO_strideM; + delta_ptr += q_start; + + grad_query_ptr += q_start * gQ_strideM(); + grad_key_ptr += k_start * gK_strideM(); + grad_value_ptr += k_start * gV_strideM(); + } + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + if (bias_ptr != nullptr) { + bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; + } + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + if (grad_bias_ptr != nullptr) { + grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; + } + + // Some values are modified above + // Signal to the compiler that they are the same in all threads + // and can be stored in warp-uniform registers (Sm75+) + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + custom_mask_type = warp_uniform(custom_mask_type); + + query_ptr = warp_uniform(query_ptr); + 
key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + bias_ptr = warp_uniform(bias_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); + grad_bias_ptr = warp_uniform(grad_bias_ptr); + +#if 0 + PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f", + int(blockIdx.z), int(blockIdx.y), + float(delta_ptr[0]), + float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), + float(logsumexp_ptr[0]) + ) +#endif + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3(num_splits_key, num_heads, num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { + if (!kNeedsAccumGradK) { + return 0; + } + return num_splits_key * kBlockSizeJ * + align_up(head_dim, kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { + if (!kNeedsAccumGradV) { + return 0; + } + return num_splits_key * kBlockSizeJ * + align_up(head_dim_value, kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { + if (!kNeedsAccumGradQ) { + return 0; + } + int num_blocks = ceil_div(num_queries, kBlockSizeI); + int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN); + return num_blocks * num_cols * sizeof(GradQTempStorage) / + sizeof(output_accum_t); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + CUTLASS_HOST_DEVICE bool should_zero_workspace() const { + return num_splits_key > 1 || window_size > 0; + } + }; + + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. + typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + + struct SharedStoragePrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + } persistent; + union { + struct { + // part1 - after Q.K / dV / dO.V + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. 
write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi + ZijSharedStorage zij; + + union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi + typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. dVj epilogue + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + + // 3. prologue for dPij_dropped + // 8. used in dPij_dropped = dOi @ Vj.T + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + } part1; + + struct { + // part2 - dQ + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; + + } part2; + + struct { + // part3 - after last iteration on dQ's epilogue / dK + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + } part3; + + struct { + // part4 - after last iteration on dK's epilogue / preload next K.Q_t + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + + // If we reach end of current key, dump RF->gmem with "final" epilogues + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part4; + }; + static void print_size() { + // Field size +#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f))) + + printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue))); + printf(" persistent: %db\n", FSZ(persistent)); + printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k)); + printf(" part1: %db\n", FSZ(part1)); + printf(" bias: %db\n", FSZ(part1.bias)); + printf(" attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage)); + printf(" zij: %db\n", FSZ(part1.zij)); + printf(" mm_gradV: %db\n", FSZ(part1.mm_gradV)); + printf(" gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue)); + printf(" mm_doivj: %db\n", FSZ(part1.mm_doivj)); + printf(" part2: %db\n", FSZ(part2)); + printf(" tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage)); + printf(" tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage)); + printf(" mm_gradK: %db\n", FSZ(part2.mm_gradK)); + printf(" mm_gradQ: %db\n", FSZ(part2.mm_gradQ)); + printf(" gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue)); + printf(" gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue)); + printf(" part3: %db\n", FSZ(part3)); + printf(" tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage)); + printf(" part4: %db\n", FSZ(part4)); + printf(" mm_qk_q: %db\n", FSZ(part4.mm_qk_q)); + printf( + " gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final)); + printf( + " gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + 
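+  // For reference: each FIELD(...) invocation below expands to a trivial
+  // accessor, e.g. FIELD(part1, bias) becomes (roughly)
+  //
+  //   CUTLASS_DEVICE auto& bias() { return part1.bias; }
+  //
+  // Both SharedStorage variants expose the same accessor names, so kernel
+  // code can write `shared_storage.bias()` etc. without knowing whether the
+  // prologue or no-prologue layout was selected.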
FIELD(persistent, di) + FIELD(persistent, mm_qk_k) + FIELD(part1, bias) + FIELD(part1, attn_shared_storage) + FIELD(part1, zij) + FIELD(part1, mm_gradV) + FIELD(part1, gradV_epilogue) + FIELD(part1, mm_doivj) + FIELD(part2, mm_gradK) + FIELD(part2, mm_gradQ) + FIELD(part2, gradB_epilogue) + FIELD(part2, gradQ_epilogue) + FIELD(part2, tmp_shared_storage) + FIELD(part3, tmpT_shared_storage) + FIELD(part3, gradQ_epilogue_lastIter) + FIELD(part3, gradK_epilogue) + FIELD(part4, mm_qk_q) + FIELD(part4, gradK_epilogue_final) + FIELD(part4, gradV_epilogue_final) + }; + + struct SharedStorageNoPrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + } persistent; + union { + struct { + // part1 - Q.K matmul + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + } part1; + + struct { + // part2 - compute gradV + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed in this step, where it is used + // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are + // loaded for the computation of dVj. + ZijSharedStorage zij; + + union { + typename MatmulGradV::Mma::SharedStorage mm_gradV; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + } part2; + + struct { + // part3 - DO.V matmul + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from part2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; + } part3; + + struct { + // part4 - compute gradQ + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + }; + } part4; + + struct { + // part5 - compute gradK + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradK::Mma::SharedStorage mm_gradK; + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + }; + } part5; + + struct { + // part6 - store RF accumulated into gmem + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part6; + }; + static void print_size() { +#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f))) + printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue))); + printf(" persistent: %db\n", FIELD_SIZEOF(persistent)); + printf(" part1: %db\n", FIELD_SIZEOF(part1)); + printf(" part2: %db\n", FIELD_SIZEOF(part2)); + printf(" part3: %db\n", FIELD_SIZEOF(part3)); + printf(" part4: %db\n", FIELD_SIZEOF(part4)); + printf(" part5: %db\n", 
FIELD_SIZEOF(part5)); + printf(" part6: %db\n", FIELD_SIZEOF(part6)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(part1, mm_qk_k) + FIELD(part1, mm_qk_q) + FIELD(part2, bias) + FIELD(part2, attn_shared_storage) + FIELD(part2, zij) + FIELD(part2, mm_gradV) + FIELD(part2, gradV_epilogue) + FIELD(part3, mm_doivj) + FIELD(part3, gradB_epilogue) + FIELD(part4, tmpT_shared_storage) + FIELD(part4, tmp_shared_storage) + FIELD(part4, mm_gradQ) + FIELD(part4, gradQ_epilogue) + FIELD(part4, gradQ_epilogue_lastIter) + FIELD(part5, mm_gradK) + FIELD(part5, gradK_epilogue) + FIELD(part6, gradK_epilogue_final) + FIELD(part6, gradV_epilogue_final) + }; + + using SharedStorage = typename cutlass::platform::conditional< + kPreload, + SharedStoragePrologue, + SharedStorageNoPrologue>::type; + + struct OutputFragments { + typename MatmulGradV::Mma::FragmentC gradV; + typename MatmulGradK::Mma::FragmentC gradK; + + CUTLASS_DEVICE void clear() { + gradV.clear(); + gradK.clear(); + } + }; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); + TORCH_CHECK( + p.num_heads <= 1 || p.lse_strideH % 8 == 0, + "LSE is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_batches <= 1 || p.lse_strideB % 8 == 0, + "LSE is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, + "query is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, + "key is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, + "value is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, + "query is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, + "key is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, + "value is not correctly aligned (strideB)"); + TORCH_CHECK( + p.q_strideM % kMinimumAlignment == 0, + "query is not correctly aligned (strideM)"); + TORCH_CHECK( + p.k_strideM % kMinimumAlignment == 0, + "key is not correctly aligned (strideM)"); + TORCH_CHECK( + p.v_strideM % kMinimumAlignment == 0, + "value is not correctly aligned (strideM)"); + if (p.bias_ptr) { + TORCH_CHECK( + p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideB). ", + "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a " + "multiple of ", kMinimumAlignment, "."); + TORCH_CHECK( + p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideH) ." + "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " + "multiple of ", kMinimumAlignment, "."); + TORCH_CHECK( + p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideM). 
" + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ", + "multiple of ", kMinimumAlignment, "."); + } + if (p.grad_bias_ptr) { + TORCH_CHECK( + p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideH)"); + TORCH_CHECK( + p.gB_strideM % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideM)"); + } + TORCH_CHECK( + !(p.cu_seqlens_q_ptr && p.bias_ptr), + "CuSeqlen + bias not implemented yet"); + TORCH_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "Invalid value for `custom_mask_type`"); + TORCH_CHECK( + p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, + "Invalid value for `dropout_prob`"); + TORCH_CHECK( + kApplyDropout || p.dropout_prob == 0.0f, + "Set `kApplyDropout`=True to support `dropout_prob > 0`"); + TORCH_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); + TORCH_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); + TORCH_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); + TORCH_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); + TORCH_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); + TORCH_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); + TORCH_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); + TORCH_CHECK( + p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + if (kKeysQueriesAlignedToBlockSize) { + TORCH_CHECK( + p.cu_seqlens_k_ptr == nullptr, + "This kernel does not support cu_seqlen"); + TORCH_CHECK( + p.cu_seqlens_q_ptr == nullptr, + "This kernel does not support cu_seqlen"); + TORCH_CHECK( + p.num_queries % kBlockSizeI == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + TORCH_CHECK( + p.num_keys % kBlockSizeJ == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + } + TORCH_CHECK( + kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); + TORCH_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + TORCH_CHECK( + p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), + "Invalid `num_splits_key` (", + p.num_splits_key, + ") - too large for `num_keys` = ", + p.num_keys); + if (p.window_size != 0) { + TORCH_CHECK( + p.custom_mask_type != NoCustomMask, + "LocalAttention only supported in causal mode"); + } + return true; + } + + static CUTLASS_DEVICE void attention_kernel(Params p) { + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + int64_t key_start = p.split_key_device() * kBlockSizeJ; + if (key_start >= p.num_keys) { + return; + } + if (kPrologueQK) { + int64_t query_start = getQueryStart(p, key_start); + prologueQkNextIteration( + shared_storage, p, query_start, key_start, warp_id, lane_id); + } + + // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` + if (kKernelComputesDelta) { + constexpr int kOptimalElements = + 128 / cutlass::sizeof_bits::value; + if (p.head_dim_value % kOptimalElements == 0) { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta(p, query_start, warp_id, lane_id); + } + } else { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta<1>(p, query_start, warp_id, lane_id); + } + 
} + __syncthreads(); + } + + OutputFragments output_frags; + + curandStatePhilox4_32_10_t rng_state_init; + + if (kApplyDropout) { + // See Note [Seed and Offset Device] + auto const [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs); + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. + curand_init( + seed, + 0, + offset + p.dropout_batch_head_rng_offset, + &rng_state_init); + } + + CUTLASS_PRAGMA_UNROLL + for (; key_start < p.num_keys; + key_start += p.num_splits_key_device() * kBlockSizeJ) { + output_frags.clear(); + + int64_t next_key = key_start; + int64_t query_start = getQueryStart(p, key_start); + while (next_key == key_start && query_start < p.num_queries) { + // This line here + // vvvvvvvvvvvvvv + warp_id = warp_uniform(warp_id); + // ^^^^^^^^^^^^^^ + // ... makes everything use less RF and be 10% faster. Why? + // I don't know. My theory is that it forces `nvcc` to + // re-compute indices, offsets etc... and not keep them + // from the previous iteration, which prevents MASSIVE + // register spilling. + + processBlockIJ( + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init, + warp_id, + lane_id); + + int64_t next_query; + incrIteration(p, query_start, key_start, next_query, next_key); + query_start = next_query; + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV( + p, key_start, warp_id, lane_id); + } + __syncthreads(); + } + } + + template + static CUTLASS_DEVICE void zfillGradKV( + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + constexpr int kThreadsPerKey = 8; + constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; + static_assert(kBlockSizeJ % kParallelKeys == 0, ""); + // This function is not really optimized, but should rarely be used + // It's only used when some keys are "useless" and don't attend to + // any query, due to causal masking + + int thread_id = 32 * warp_id + lane_id; + int k_shift = lane_id % kThreadsPerKey; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { + int key = key_start + j + (thread_id / kThreadsPerKey); + if (!skipBoundsChecks && key >= p.num_keys) { + continue; + } + auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); + auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); + + for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { + gv_ptr[k] = scalar_t(0); + } + for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { + gk_ptr[k] = scalar_t(0); + } + } + } + + template + static CUTLASS_DEVICE void processBlockIJ( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params& p, + int64_t query_start, + int64_t key_start, + const curandStatePhilox4_32_10_t& curand_state_init, + uint8_t warp_id, + uint8_t lane_id) { + cutlass::Array + dropout_keep_mask_doivj; + dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1}); + const float dropout_scale = + kApplyDropout ? 
1.0 / (1.0 - p.dropout_prob) : 1.0f; + + cutlass::MatrixCoord no_offset{0, 0}; + accum_t scale = p.scale; + int16_t thread_id = 32 * warp_id + lane_id; + + auto rematerializeThreadIds = [&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; + + bool isFirstQuery = (query_start == getQueryStart(p, key_start)); + int64_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + bool isLastQuery = next_key != key_start; + + accum_t di_rf = accum_t(0); + if (thread_id < kBlockSizeI) { + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + shared_storage.di()[thread_id] = di_rf; + } + + int32_t num_queries_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kN + : warp_uniform(cutlass::fast_min( + MatmulQK::Mma::Shape::kN, (int32_t)(p.num_queries - query_start))); + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : warp_uniform(cutlass::fast_min( + MatmulQK::Mma::Shape::kM, (int32_t)(p.num_keys - key_start))); + + auto prologueGradV = [&](int64_t col) { + typename MatmulGradV::Mma::IteratorB iterator_dO( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM + col), + {num_queries_in_block, (int32_t)(p.head_dim_value - col)}, + thread_id, + no_offset); + MatmulGradV::Mma::prologue( + shared_storage.mm_gradV(), + iterator_dO, + thread_id, + num_queries_in_block); + }; + auto prologueGradQ = [&](int col) { + typename MatmulGradQ::Mma::IteratorB iterator_K( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM + col), + {num_keys_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradQ::Mma::prologue( + shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); + }; + auto prologueGradK = [&](int col) { + typename MatmulGradK::Mma::IteratorB iterator_Q( + {int32_t(p.q_strideM)}, + const_cast(p.query_ptr + query_start * p.q_strideM + col), + {num_queries_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradK::Mma::prologue( + shared_storage.mm_gradK(), + iterator_Q, + thread_id, + num_queries_in_block); + }; + auto prologueDOV = [&]() { + typename MatmulDOIVJ::Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM), + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + typename MatmulDOIVJ::Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + const_cast(p.value_ptr + key_start * p.v_strideM), + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + MatmulDOIVJ::Mma::prologue( + shared_storage.mm_doivj(), + iterator_A, + iterator_B, + thread_id, + p.head_dim_value); + }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulQK + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulQK::Mma; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + num_queries_in_block, + p.head_dim // k + ); + + // k_j + typename Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM), + {problem_size.m(), problem_size.k()}, + thread_id, + no_offset); + + // q_i.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + 
const_cast(p.query_ptr + query_start * p.q_strideM), + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + Mma mma( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma.set_prologue_done(kPrologueQK); + mma.set_zero_outside_bounds(!skipBoundsChecks); + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + accum = cutlass::multiplies()(scale, accum); + + // Epilogue: add LSE + exp and store that to our shared memory buffer + // shmem <- (matmul_result - + // logsumexp[i_start:i_end].unsqueeze(1)).exp() + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + // apply bias if applicable + if (p.bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MatmulQK::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + const_cast(p.bias_ptr + query_start * p.bias_strideM + key_start), + {num_queries_in_block, num_keys_in_block}, + thread_id); + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id); + MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, where Pij is in register fragment and Bij is in shmem + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); + }, + [&](int accum_n) {}); + } + + // Apply mask + if (p.custom_mask_type == CausalFromTopLeft || + p.custom_mask_type == CausalFromBottomRight) { + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + int shift = query_start - key_start; + if (p.custom_mask_type == CausalFromBottomRight) { + shift += p.num_keys - p.num_queries; + } + // current_key = key_start + accum_m + // current_query = query_start + accum_n + // mask if: `current_key > current_query` + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m > accum_n + shift) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + if (p.window_size > 0) { + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + int shift = query_start - key_start - p.window_size; + // current_key = key_start + accum_m + // current_query = query_start + accum_n + // mask if: `current_key < current_query - window_size` + // if accum_m < accum_n + query_start - window_size - key_start + + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m <= accum_n + shift) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + __syncthreads(); + if (kPrologueGV) { + prologueGradV(0); + } + if 
(kPrologueDOV) { + prologueDOV(); + } + + MatmulQK::B2bGemm::accumApplyLSEToSmem( + shared_storage.attn_shared_storage(), + accum, + p.logsumexp_ptr + query_start, + problem_size.n(), + thread_id, + warp_id, + lane_id, + output_tile_coords); +#if 0 + auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref(); + PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT); +#endif + + // if we are using dropout, compute Zij, writing it to shared memory. + // each element of Zij is: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + if (kApplyDropout) { + auto zij = shared_storage.zij().accum_ref(); + // each thread generates a contiguous sequence of elements in Zij, all + // in the same row. the reason they have to come from the same row is + // that sampling random numbers from a contiguous random number sequence + // is much more efficient than jumping around, and the linear offset of + // each element of Z (the global matrix) maps to an offset in a random + // number sequence. for Z, the end of a row and the beginning of the + // next have adjacent offsets, but for Zij (tile of global matrix), this + // is not necessarily the case. + // We must fill the entire `zij` shmem with values (even out of bounds + // on the K-dimension) otherwise we can get NaNs during the GEMM + const int kQueriesPerBlock = kBlockSizeI; + const int threads_per_row = cutlass::fast_min( + kNumThreads / kQueriesPerBlock, (int64_t)num_keys_in_block); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); + + const int thread_i = thread_id / threads_per_row; + const int thread_start_j = + (thread_id % threads_per_row) * elts_per_thread; + + if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + (query_start + thread_i) * p.num_keys + + (key_start + thread_start_j), + &curand_state); + + // generate elements of Zij, 4 elements at a time + for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + num_keys_in_block); + zij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + // we'll write Zij transposed since attention is also transposed + // during the matmul to compute dV. + zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) = + (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob + ? 
scalar_t(dropout_scale) + : scalar_t(0); + } + } + } + __syncthreads(); +#if 0 + PRINT_TENSOR4x4_T0_L0("zij", zij); + PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4); +#endif + + // Save mask for later DOIVJ matmul + + int warp_idx_mn_0 = warp_id % + (MatmulDOIVJ::Mma::Base::WarpCount::kM * + MatmulDOIVJ::Mma::Base::WarpCount::kN); + auto output_tile_coords_doivj = cutlass::MatrixCoord{ + warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM}; + auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords_doivj); + MatmulDOIVJ::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m /*q*/, int accum_n /*k*/, int idx) { + if (zij.at({accum_n, accum_m}) == scalar_t(0)) { + dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0}; + } + }, + [&](int accum_m) {}); + } + __syncthreads(); + } + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradV matmul + // + // grad_v[j_start:j_end] += attn_T @ do_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + constexpr bool kSingleIterationGradV = + kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int32_t col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value); + col += MatmulGradV::ThreadblockShape::kN) { + using Mma = typename MatmulGradV::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, p.head_dim_value - col, num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradV::OutputTileIterator( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, + {num_keys_in_block, p.head_dim_value - col}, + thread_id); + }; + typename Mma::IteratorB iterator_B( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM + col), + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi + Mma mma( + // operand A: Pij.T + shared_storage.attn_shared_storage().accum_ref(), + // operand A_scale Zij.T: + // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T + // which is computed on the fly as fragments of Pij.T are loaded in + shared_storage.zij().accum_ref(), + // operand B: dOi - which was loaded into shared memory previously + // when we computed dVj + shared_storage.mm_gradV().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradV::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradV) { + output_frags.gradV.clear(); + } else { + gmem_tile.load(output_frags.gradV, thread_id); + } + } + mma.set_prologue_done(kPrologueGV); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradV, + iterator_B, + output_frags.gradV); + __syncthreads(); + if (kPrologueGV && !kSingleIterationGradV && + col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { + prologueGradV(col + MatmulGradV::ThreadblockShape::kN); + } + + if (!kOutputInRF) { + if (kNeedsAccumGradV 
&& !isLastQuery) { + gmem_tile.store(output_frags.gradV, thread_id); + } else { + accumulateInGmem( + shared_storage.gradV_epilogue(), + output_frags.gradV, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); + } + } + } + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulDOIVJ + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulDOIVJ::Mma; + // do_i + typename Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM), + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + + // v_j.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + const_cast(p.value_ptr + key_start * p.v_strideM), + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + + Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); + mma.set_prologue_done(kPrologueDOV); + mma.set_zero_outside_bounds(!skipBoundsChecks); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + if (kPrologueGQ) { + prologueGradQ(0); + } + if (kPrologueGK) { + prologueGradK(0); + } + + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + // TODO: This must be terribly inefficient. There must be a better way + // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] + // attn_shared_storage [smem] <- tmp.T + // tmp_shared_storage [smem] <- tmp + { + using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator; + auto lane_offset = LambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + // if dropout was used, compute dPij = dPij_dropped * Zij + if (kApplyDropout) { + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (dropout_keep_mask_doivj[idx].get()) { + accum[idx] *= dropout_scale; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) {}); + } + + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); +#if 0 + PRINT_B0_T0("doivj_dropped"); + print_warp_accum(accum, lane_offset, 4, 4); + PRINT_TENSOR4x4_T0_L0("attn_T", attn_T) +#endif + accum_t current_di; + // dSij = (dPij - Di) * Pij + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + // TODO: Otherwise we can get nans as we + // might have infs here (only seen on f16 tho) + if (skipBoundsChecks || + (accum_m < num_queries_in_block && + accum_n < num_keys_in_block)) { + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) { + + }); + + // store bias gradient tile dBij to global memory, + // where dBij = dSij = Pij * (dPij - Di) + if (p.grad_bias_ptr != nullptr) { + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator + output_iter( + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator:: + Params{p.gB_strideM}, + // grad_bias_ptr is offset to point at beginning of + // 
matrix of shape (queries, keys) for a given + // (batch_id, head_id) the pointer arithmetic here produces + // a pointer to the start of the current tile within that + // matrix + p.grad_bias_ptr + query_start * p.gB_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + + // no-op epilogue operator - just casting and storing contents of + // accum to global memory + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue epilogue( + shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); + epilogue(output_op, output_iter, accum, output_iter); + } + + accum = accum * scale; + +#if 0 + PRINT_B0_T0("(doivj - di) * attn * scale"); + print_warp_accum(accum, lane_offset, 4, 4); +#endif + + __syncthreads(); + if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { + auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); + // attn <- attn_T.T + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); + }, + [&](int accum_m) {}); + } + } + + MatmulDOIVJ::B2bGemm::accumToSmem( + shared_storage.tmp_shared_storage(), + accum, + lane_id, + output_tile_coords); + __syncthreads(); + } + // Force `nvcc` to recompute values that depend on the variables just below + // to use less RF and prevent some spilling + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradQ matmul + // + // grad_q[i_start:i_end] += tmp @ k_j + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = + kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); + col += MatmulGradQ::ThreadblockShape::kN) { + using Mma = typename MatmulGradQ::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_queries_in_block, + false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, + num_keys_in_block); + + // k_j + typename Mma::IteratorB iterator_B( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM + col), + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto a = shared_storage.tmp_shared_storage().accum_ref(); + Mma mma( + // operand A: dSij + shared_storage.tmp_shared_storage().accum_ref(), + // operand B: Kj + shared_storage.mm_gradQ().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + int col_id = col / MatmulGradQ::ThreadblockShape::kN; + int num_cols = kSingleIterationGradQ + ? 1 + : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + + if (p.num_splits_key_device() > 1) { + AtomicLock::acquire( + &p.workspace_gq[storage_id].lock, + p.split_key_device() + 1, + thread_id); + // Make sure we can see other block's output + __threadfence(); + } + + AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]}; + if (!kNeedsAccumGradQ || + (p.num_splits_key_device() == 1 && key_start == 0)) { + // if we know we are the first to access it, we know it's only zeros. 
+ // Avoids a load from gmem (and gmem init as well) + accum.clear(); + } else { + gmem_tile.load(accum, thread_id); + } + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + mma.set_prologue_done(kPrologueGQ); + mma(gemm_k_iterations, accum, iterator_B, accum); + __syncthreads(); + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); + if (kPrologueGQ && !isLastColumn) { + prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); + } + + bool isLast = [&]() { + int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + if (p.num_keys <= next_key) { + return true; + } + if (query_start < getSmallestQueryForKey(p, next_key)) { + return true; + } + return false; + }(); + // Output results + if (p.num_splits_key_device() > 1) { + int32_t numAddsSoFar = -1; + if (isLast && thread_id == 0) { + numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) + + 1; // `atomicAdd` returns the old value + } + isLast = __syncthreads_or( + numAddsSoFar == getNumParallelBlocksForQuery(p, query_start)); + assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start)); + } + if (kNeedsAccumGradQ && !isLast) { + gmem_tile.store(accum, thread_id); + if (p.num_splits_key_device() > 1) { + // Make sure everyone wrote before we release the lock + __threadfence(); + __syncthreads(); + AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id); + } + } else { + // NOTE: We're not releasing the lock because no one is expected + // to come after us (we're the last one to write) + typename MatmulGradQ::OutputTileIterator output_it( + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, + {problem_size.m(), problem_size.n()}, + thread_id); + // if `direct_store` is True, we store to gmem (`*gmem = accum`) + // otherwise, we accumulate in gmem (`*gmem = *gmem + accum`) + // If we know ahead of time when we will write for the first time + // we can: + // (1) Avoid an additional memory read + // (2) Avoid the cost of initializing memory to 0 + bool direct_store = kNeedsAccumGradQ || key_start == 0 || + (p.num_splits_key_device() > 1); + accumulateInGmem( + isLastColumn ? shared_storage.gradQ_epilogue_lastIter() + : shared_storage.gradQ_epilogue(), + accum, + output_it, + direct_store, + warp_id, + lane_id); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradK matmul + // + // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = + kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); + col += MatmulGradK::ThreadblockShape::kN) { + using Mma = typename MatmulGradK::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, + num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradK::OutputTileIterator( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, + {num_keys_in_block, + false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, + thread_id); + }; + + // q_i + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + const_cast(p.query_ptr + query_start * p.q_strideM + col), + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; + auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; + // this is basically: + // opA = kIsTransposedA ? getTmp() : getTmpT(); + bool constexpr kIsTransposedA = + MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; + auto& opA = *call_conditional< + kIsTransposedA, + decltype(getTmp), + decltype(getTmpT)>::apply(getTmp, getTmpT, 0); + Mma mma( + // operand A: dSij.T + opA.accum_ref(), + // operand B: Qi + shared_storage.mm_gradK().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradK::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradK) { + output_frags.gradK.clear(); + } else { + gmem_tile.load(output_frags.gradK, thread_id); + } + } + mma.set_prologue_done(kPrologueGK); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradK, + iterator_B, + output_frags.gradK); + __syncthreads(); + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + if (kPrologueGK && !isLastColumn) { + prologueGradK(col + MatmulGradK::ThreadblockShape::kN); + } + + if (kPrologueQK && isLastColumn) { + int64_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + DISPATCH_BOOL( + next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); + } + + // Output results + if (!kOutputInRF) { + if (kNeedsAccumGradK && !isLastQuery) { + gmem_tile.store(output_frags.gradK, thread_id); + } else { + accumulateInGmem( + isLastColumn ? shared_storage.gradK_epilogue_final() + : shared_storage.gradK_epilogue(), + output_frags.gradK, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); + } + } + } + } + + static CUTLASS_HOST_DEVICE int64_t getQueryStartShift(Params const& p) { + if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { + return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); + } + return 0; + } + + // Iteration order logic + static CUTLASS_HOST_DEVICE int64_t + getQueryStart(Params const& p, int64_t key_start) { + return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); + }; + static CUTLASS_HOST_DEVICE int64_t getQueryEnd(Params const& p) { + return align_up(p.num_queries, kBlockSizeI); + }; + + static CUTLASS_HOST_DEVICE int64_t + getSmallestQueryForKey(Params const& p, int64_t key_start) { + if (p.custom_mask_type == NoCustomMask) { + return 0; + } + int64_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int64_t window_size = + p.window_size == 0 ? 
p.num_queries + p.num_keys : p.window_size; + + auto last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, (int64_t)p.num_keys) - 1; + int first_query = key_start - shift; + int last_query = last_key_for_block - shift + window_size - 1; + if (last_query < 0 || first_query >= p.num_queries) { + return getQueryEnd(p); // nothing to compute in this column + } + first_query = cutlass::fast_max(0, first_query); + return (first_query / kBlockSizeI) * kBlockSizeI; + }; + + // Returns how many kernel blocks will write to a given block in `grad_query` + // This is usually equal to the number of key splits, but can be different + // for instance in the causal case, or varying seqlen + static CUTLASS_HOST_DEVICE int32_t + getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { + int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); + if (p.custom_mask_type != NoCustomMask) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int32_t last_query_for_block = + cutlass::fast_min(query_start + kBlockSizeI, p.num_queries) - 1; + int32_t last_key_for_block = + cutlass::fast_min(last_query_for_block + shift, p.num_keys - 1); + int32_t first_key_for_block = p.window_size == 0 + ? 0 + : cutlass::fast_max(query_start - p.window_size + 1 + shift, 0); + + if (p.window_size == 0) { + num_key_blocks = last_key_for_block / kBlockSizeJ + 1; + } else { + num_key_blocks = (last_key_for_block / kBlockSizeJ) - + (first_key_for_block / kBlockSizeJ) + 1; + } + + if (last_key_for_block < 0 || first_key_for_block >= p.num_keys) { + num_key_blocks = 0; + } + } + return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); + }; + + // Returns the next block to process + static CUTLASS_HOST_DEVICE void incrIteration( + Params const& p, + int64_t query_start, + int64_t key_start, + int64_t& next_query, + int64_t& next_key) { + next_query = query_start + kBlockSizeI; + next_key = key_start; + auto query_shift = getQueryStartShift(p); + // Wrap around + if (query_shift) { + if (next_query >= p.num_queries) { + next_query = getSmallestQueryForKey(p, key_start); + return; + } else if (query_start < query_shift && query_shift <= next_query) { + // jump to next key + } else { + return; + } + } else { + if (p.window_size > 0) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? 
p.num_keys - p.num_queries + : 0; + // last key that is not masked out + int last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, (int64_t)p.num_keys) - 1; + int last_query = last_key_for_block - shift + p.window_size - 1; + if (next_query <= last_query && next_query < p.num_queries) { + return; + } + } else if (next_query < p.num_queries) { + return; + } + // jump to next key + } + // Next key + next_key = key_start + p.num_splits_key_device() * (int64_t)kBlockSizeJ; + next_query = getQueryStart(p, next_key); + } + + template + static CUTLASS_DEVICE void prologueQkNextIteration( + SharedStorage& shared_storage, + Params const& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + if (query_start >= p.num_queries || key_start >= p.num_keys) { + return; + } + + static constexpr bool kReloadK = + kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; + int thread_id = 32 * warp_id + lane_id; + typename MatmulQK::Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM), + {p.num_keys - key_start, p.head_dim}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + typename MatmulQK::Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + const_cast(p.query_ptr + query_start * p.q_strideM), + {p.head_dim, p.num_queries - query_start}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + MatmulQK::Mma::prologue( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + iterator_A, + iterator_B, + thread_id, + p.head_dim); + } + + template + static CUTLASS_DEVICE void writeFragsToGmem( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + uint16_t thread_id = 32 * warp_id + lane_id; + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : cutlass::fast_min( + MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + typename MatmulGradV::OutputTileIterator outputV_it( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), + {num_keys_in_block, p.head_dim_value}, + thread_id); + accumulateInGmem( + shared_storage.gradV_epilogue_final(), + output_frags.gradV, + outputV_it, + true, + warp_id, + lane_id); + + typename MatmulGradK::OutputTileIterator outputK_it( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, + thread_id); + accumulateInGmem( + shared_storage.gradK_epilogue_final(), + output_frags.gradK, + outputK_it, + true, + warp_id, + lane_id); + } + + template + static CUTLASS_DEVICE void accumulateInGmem( + typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, + typename MatmulT::Mma::FragmentC const& accum, + typename MatmulT::OutputTileIterator output_it, + bool first, + uint8_t warp_id, + uint8_t lane_id) { + using DefaultEpilogue = typename MatmulT::DefaultEpilogue; + using DefaultOutputOp = typename MatmulT::DefaultOutputOp; + using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; + DISPATCH_BOOL( + first, kIsFirst, ([&]() { + static constexpr auto ScaleType = kIsFirst + ? 
cutlass::epilogue::thread::ScaleType::Nothing + : cutlass::epilogue::thread::ScaleType::NoBetaScaling; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + typename DefaultOutputOp::ElementOutput, + DefaultOutputOp::kCount, + typename DefaultOutputOp::ElementAccumulator, + typename DefaultOutputOp::ElementCompute, + ScaleType>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MatmulT::OutputTileIterator, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale({1, 1}); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); + epilogue(rescale, output_it, accum, output_it); + })); + } + + template + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) { + // Each thread computes one value for Delta + // Depending on warp configuration, we might have multiple + // threads of the same warp working on the same row + using AccessType = cutlass::Array; + static_assert(kNumThreads >= kBlockSizeI, ""); + static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; + int16_t thread_id = 32 * warp_id + lane_id; + + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); + int16_t laneRow = thread_id / kNumThreadsPerLine; + bool rowPred = (query_start + laneRow) < p.num_queries; + bool pred = rowPred; + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + + laneFirstCol); + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + + laneFirstCol); + + static constexpr int64_t kMaxIters = + kMaxK / (kElementsPerAccess * kNumThreadsPerLine); + constexpr int kPipelineStages = 2; + accum_t delta_value = accum_t(0); + using GlobalLoad = + cutlass::arch::global_load; + AccessType frag_grad_output[kPipelineStages]; + AccessType frag_output[kPipelineStages]; + + auto loadAndIncrement = [&](int ld_pos, bool is_valid) { + frag_grad_output[ld_pos].clear(); + frag_output[ld_pos].clear(); + GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); + GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); + grad_output_ptr += kNumThreadsPerLine; + output_ptr += kNumThreadsPerLine; + }; + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kPipelineStages - 1; ++iter) { + int ld_pos = iter % kPipelineStages; + pred = pred && + (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < + p.head_dim_value; + loadAndIncrement(ld_pos, pred); + } + auto columnIteration = [&](int iter) { + // Load for next iter + int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; + pred = pred && + (laneFirstCol + + (iter + kPipelineStages - 1) * kElementsPerAccess * + kNumThreadsPerLine) < p.head_dim_value; + loadAndIncrement(ld_pos, pred); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccessType::kElements; ++i) { + delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * + accum_t(frag_grad_output[iter % kPipelineStages][i]); + } + }; + 
+ // If we have a small lower-bound for K, we can unroll the loop + if (kMaxK <= 256) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kMaxIters; ++iter) { + columnIteration(iter); + } + } else { + int num_iters = + ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * + (kElementsPerAccess * kNumThreadsPerLine); + for (int iter = 0; iter < num_iters; ++iter) { + columnIteration(iter); + } + } + + // Reduce between workers + static_assert( + kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || + kNumThreadsPerLine == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kNumThreadsPerLine; i *= 2) { + delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); + } + + // Store in gmem + if (rowPred) { + p.delta_ptr[query_start + laneRow] = delta_value; + } + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); + +} // namespace PyTorchMemEffAttention diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..b841c9d1496ea01c388ca275f50232b1477cefd6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -0,0 +1,1353 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +using namespace gemm_kernel_utils; + +namespace PyTorchMemEffAttention { +namespace { +template +constexpr int getWarpsPerSmFw() { + return ( + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return !signbit(value) + ? 
__int_as_float(atomicMax((int *)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int *)addr, __float_as_uint(value))); +} +} // namespace + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock_, + int kKeysPerBlock_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsDropout_ = true, + bool kSupportsBias_ = true> +struct AttentionKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kSupportsDropout = kSupportsDropout_; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr int kQueriesPerBlock = kQueriesPerBlock_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && kIsHalf; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSmFw() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + const scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim] + const scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] + const scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] + const scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] + const int32_t* seqstart_q_ptr = nullptr; + const int32_t* seqstart_k_ptr = nullptr; + + const int32_t* seqlen_k_ptr = nullptr; + uint32_t causal_diagonal_offset = 0; + + // Output tensors + output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value] + // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr = nullptr; + // [num_heads, num_queries] - can be null + lse_scalar_t* logsumexp_ptr = nullptr; + + // Sliding window. 
ignored if == 0 + int32_t window_size = 0; + + // Scale + accum_t scale = 0.0; + + // Dimensions/strides + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; + + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; + int32_t bias_strideM = 0; + + int32_t o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; + int64_t bias_strideH = 0; + + int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; + int64_t bias_strideB = 0; + + int32_t num_batches = 0; + int32_t num_heads = 0; + + // dropout + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); + int64_t* extragraph_offset = nullptr; + int64_t* seed = nullptr; + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + if (kSupportsDropout) { + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + } + + int64_t q_start = 0, k_start = 0; + // Advance to current batch - in case of different sequence lengths + if (seqstart_q_ptr != nullptr) { + assert(seqstart_k_ptr != nullptr); + seqstart_q_ptr += batch_id; + + q_start = seqstart_q_ptr[0]; + int64_t q_next_start = seqstart_q_ptr[1]; + int64_t k_end; + seqstart_k_ptr += batch_id; + + if (seqlen_k_ptr) { + k_start = seqstart_k_ptr[0]; + k_end = k_start + seqlen_k_ptr[batch_id]; + } else { + k_start = seqstart_k_ptr[0]; + k_end = seqstart_k_ptr[1]; + } + + num_queries = q_next_start - q_start; + num_keys = k_end - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += + int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (custom_mask_type == CausalFromBottomRight) { + causal_diagonal_offset = num_keys - num_queries; + } + // We use 
num_keys_absolute to index into the rng_state + // We need this index to match between forward and backwards + num_keys_absolute = num_keys; + if (custom_mask_type == CausalFromTopLeft || + custom_mask_type == CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + num_keys = cutlass::fast_min( + int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), + num_keys); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 && + logsumexp_ptr == nullptr && window_size == 0) { + if (head_id % kQueriesPerBlock != 0) { + return false; + } + q_strideM = q_strideH; + bias_strideM = bias_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + custom_mask_type = NoCustomMask; + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + // Only worth doing if they could have been modified above. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + if (kSupportsBias) { + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + o_strideM = warp_uniform(o_strideM); + custom_mask_type = warp_uniform(custom_mask_type); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3( + ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + static_assert( + MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert( + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + TORCH_CHECK( + p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideB). 
", + "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a " + "multiple of ", kAlignmentQ, "."); + TORCH_CHECK( + p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideH). " + "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " + "multiple of ", kAlignmentQ, "."); + TORCH_CHECK( + p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideM). " + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a " + "multiple of ", kAlignmentQ, "."); + } + TORCH_CHECK( + p.q_strideM % kAlignmentQ == 0, + "query is not correctly aligned (strideM)"); + TORCH_CHECK( + p.k_strideM % kAlignmentK == 0, + "key is not correctly aligned (strideM)"); + TORCH_CHECK( + p.v_strideM % kAlignmentV == 0, + "value is not correctly aligned (strideM)"); + TORCH_CHECK( + p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned (strideH)"); + TORCH_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "invalid value for `custom_mask_type`"); + if (p.window_size > 0) { + TORCH_CHECK( + p.custom_mask_type == CausalFromTopLeft || + p.custom_mask_type == CausalFromBottomRight, + "custom_mask_type not supported"); + } + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{ + (int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + curandStatePhilox4_32_10_t curand_state_init; + if (kSupportsDropout && p.use_dropout) { + const auto [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs); + if (p.rng_engine_inputs.captured_) { + // See Note [Seed and Offset Device] + // When we are in cuda graph capture mode the seed and offset are stored + // on device We pass in 
int64_t* seed, and int64_t* offset to act as + // scratch space for storing the rng state during the forward pass and + // saving for backwards. + *p.seed = seed; + *p.extragraph_offset = offset; + } + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. + curand_init( + seed, + 0, + offset + p.dropout_batch_head_rng_offset, + &curand_state_init); + } + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + if (p.window_size > 0) { + // don't compute anything if below attention band + if (iter_key_start + kKeysPerBlock < + int32_t(query_start + p.causal_diagonal_offset) - p.window_size) { + continue; + } + } + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + const_cast(p.value_ptr + iter_key_start * p.v_strideM), + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{ + tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + const_cast(p.query_ptr), + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + const_cast(p.key_ptr + iter_key_start * p.k_strideM), + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_uniform(warp_id()); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, 
accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + if (kSupportsBias) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + const_cast(p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start), + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + // This is only needed if upper-right corner of current query / key block + // intersects the mask Coordinates of upper-right corner of current block + // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The + // first masked element is x = y + offset -> query_start + offset There is + // intersection (and we need to mask) if min(iter_key_start + + // kKeysPerBlock, num_keys)) >= query_start + offset + if (p.custom_mask_type && + cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= + (query_start + p.causal_diagonal_offset)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + // last absolute col is (last absolute query + offset) + // last local col is (last absolute query + offset - + // iter_key_start) + last_col = query_start + accum_m + p.causal_diagonal_offset - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + + // Mask out lower left corner of block if window_size > 0 + // only required if current block intersects with the lower left corner + // block starts at x_lowerleft = iter_key_start // y = query_start + + // kQueriesPerBlock first non masked value at this y is : x_first = + // query_start + kQueriesPerBlock - window_size mask if x_fist > + // x_lowerleft + + if (p.window_size > 0 && + (query_start + p.causal_diagonal_offset + + cutlass::fast_min( + int32_t(kQueriesPerBlock), int32_t(p.num_queries)) - + p.window_size >= + iter_key_start)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto 
lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + int32_t first_col; + const int32_t offset = query_start + p.causal_diagonal_offset - + p.window_size - iter_key_start; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { first_col = accum_m + offset; }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n <= first_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + // print_warp_accum(accum, lane_offset, 12, + // 12); + } + + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + my_lane_id, + thread_id(), + my_warp_id, + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. its possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (kSupportsDropout && p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that the sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. 
+ const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast( + (query_start + thread_i) * p.num_keys_absolute + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kSingleValueIteration + ? 1 + : ceil_div( + (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + const_cast(p.value_ptr + iter_key_start * p.v_strideM), + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)my_warp_id, + (int)my_lane_id); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + int first_key = 0; + if (p.window_size > 0) { + first_key = (cutlass::fast_max( + int(query_start + p.causal_diagonal_offset) - + p.window_size + 1, + 0) / + kKeysPerBlock) * + kKeysPerBlock; + } + + // int first_key_block = 0; + // MM1::Mma::drain_cp_asyncs(); # TODO figure out if this is needed for correctness + DISPATCH_BOOL( + iter_key_start == first_key, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using 
ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = call_conditional< + kIsLast, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + my_warp_id, + my_lane_id); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. 
Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (thread_id() < p.num_queries) { + // We set fully masked out rows to 0, the sumexp for masked out rows will be 0 + // We update it to be 1 prior to calling log so that log(1) = 0 + s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1: s_prime[thread_id()]; + mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits::infinity()) ? 0: mi[thread_id()]; + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. 
Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x; + } + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.y; + } + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x + threadIdx.y * blockDim.x; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); + +} // namespace PyTorchMemEffAttention diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h new file mode 100644 index 0000000000000000000000000000000000000000..33c54ce41e7ed2a3a237cdc8bdeb95e4cbb8ae9e --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h @@ -0,0 +1,914 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// This file is auto-generated. See "generate_kernels.py" +#pragma once +#include +using namespace PyTorchMemEffAttention; +// ======== f16 / sm70 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm70(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm70); + 
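  // Note (illustrative comment, not part of the generated file): every
  // cb(kernel, launcher) registration in these dispatch_* helpers pairs one
  // pre-instantiated AttentionBackwardKernel configuration (tile shape, max
  // head dim, alignment, dropout variant) with its __global__ launcher. The
  // caller is expected to walk the registrations and keep the first one whose
  // static traits fit the runtime problem. A minimal caller sketch, with
  // hypothetical names (head_dim, chosen) and assuming the kernel type
  // exposes a kMaxK limit:
  //
  //   bool chosen = false;
  //   dispatch_cutlassB<cutlass::half_t>([&](auto kernel, auto launcher) {
  //     using Kernel = decltype(kernel);
  //     if (!chosen && Kernel::kMaxK >= head_dim) {
  //       chosen = true;
  //       // build a typename Kernel::Params here and launch `launcher`
  //       // with Kernel::kNumThreads threads per block
  //     }
  //   }, compute_capability);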
cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70); +} + +// ======== bf16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80(typename 
AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_bf16_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_sm80); + if (cc == 86 || cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k96_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80); +} + +// ======== f16 / sm80 ======== +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) 
+fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm80); + if (cc == 86 || cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k96_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80); +} + +// ======== f16 / sm50 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void 
__launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm50(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50); +} + +// ======== f32 / sm50 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM) +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +#else +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +#endif +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm50(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), 
fmha_cutlassB_f32_aligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm50); +#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM) + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50); +#else + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50); +#endif + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50); +} + +// ======== f32 / sm70 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void 
__launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm70(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70); +} + +// ======== f16 / sm75 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) 
+fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm75(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75); +} + +// ======== f32 / sm75 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + 
AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm75(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), 
fmha_cutlassB_f32_notaligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75); +} + +// ======== f32 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), 
fmha_cutlassB_f32_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80); +} + + +template +void dispatch_cutlassB(T cb, int cc = 0) { + + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassB_f16_sm70(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_bf16_sm80(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_f16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassB_f16_sm50(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassB_f32_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassB_f32_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassB_f16_sm75(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassB_f32_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_f32_sm80(cb, cc); + } +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h new file mode 100644 index 0000000000000000000000000000000000000000..32f624c1c8db439966c714ee1340ea5f751a641e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h @@ -0,0 +1,313 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// This file is auto-generated. 
See "generate_kernels.py" +#pragma once +#include +using namespace PyTorchMemEffAttention; +// ======== bf16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_bf16_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_32x128_gmem_sm80); +} + +// ======== f16 / sm50 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm50(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm50); +} + +// ======== f16 / sm70 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void 
__launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm70(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm70); +} + +// ======== f16 / sm75 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm75(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm75); +} + +// ======== f16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm80); +} + +// ======== f32 / sm50 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm50(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm50); +} + +// ======== f32 / sm70 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm70(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm70); +} + +// ======== f32 / sm75 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm75(typename 
AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm75(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm75); +} + +// ======== f32 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm80); +} + + +template +void dispatch_cutlassF(T cb, int cc = 0) { + + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_bf16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassF_f16_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassF_f16_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassF_f16_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_f16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassF_f32_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassF_f32_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassF_f32_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_f32_sm80(cb, cc); + } +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..2726dd3ff3e6f393d679dc6bf216190f1d4b0eca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include + +#include +#include + + +template +struct CutlassToAtenDtype; + +template <> +struct CutlassToAtenDtype { + using scalar_t = cutlass::half_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Half; + } +}; + +template <> +struct CutlassToAtenDtype { + using scalar_t = cutlass::bfloat16_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::BFloat16; + } +}; + +template <> +struct CutlassToAtenDtype { + using scalar_t = float; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Float; + } +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..0471506058ffbaf3db37733f638aa1ab4f180ed2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/sdp_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/sdp_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..02177f99fb86ed7c15f0d21e7cc7f39c19f64e94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/cuda/sdp_utils.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include + +namespace sdp { + +bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug); +SDPBackend select_sdp_backend(sdp_params const& kernel_params); +C10_EXPORT bool is_flash_attention_available(); +C10_EXPORT bool 
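// Illustrative sketch (not part of the vendored header): CutlassToAtenDtype
// above is a compile-time map from a CUTLASS element type to the matching
// ATen scalar type, so templated launch code can allocate outputs of the
// right dtype without hand-written switches. The helper below is hypothetical.
template <typename cutlass_scalar_t>
at::Tensor allocate_output_for(const at::Tensor& query) {
  // cutlass::half_t -> at::ScalarType::Half, cutlass::bfloat16_t -> BFloat16,
  // float -> at::ScalarType::Float, per the specializations above.
  return at::empty_like(
      query,
      query.options().dtype(CutlassToAtenDtype<cutlass_scalar_t>::atScalarType()));
}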
can_use_flash_attention(sdp_params const& params, bool debug); +C10_EXPORT bool can_use_mem_efficient_attention(sdp_params const& params, bool debug); +C10_EXPORT bool can_use_cudnn_attention(sdp_params const& params, bool debug); + +} // namespace sdp diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/aotriton_adapter.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/aotriton_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..894eac7ec90362fdee72df09ea9ebc1ae8d337b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/aotriton_adapter.h @@ -0,0 +1,126 @@ +#pragma once + +#ifdef USE_ROCM + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h +//////////////////////////////////////////////////////////////////////////////// + +namespace sdp { + +namespace aotriton_adapter { + +inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) +{ +#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname + CAST_TYPE(kByte, kUInt8); + CAST_TYPE(kUInt16, kUInt16); + CAST_TYPE(kUInt32, kUInt32); + CAST_TYPE(kUInt64, kUInt64); + CAST_TYPE(kChar, kInt8); + CAST_TYPE(kShort, kInt16); + CAST_TYPE(kInt, kInt32); + CAST_TYPE(kLong, kInt64); + CAST_TYPE(kHalf, kFloat16); + CAST_TYPE(kFloat, kFloat32); + CAST_TYPE(kBFloat16, kBFloat16); + return aotriton::DType::kUnknown; +#undef CAST_TYPE +} + +template +struct IntArrayRefCaster { + // std::array cast(IntArrayRef); +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ static_cast(ref.at(0)) }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)), + static_cast(ref.at(3)) + }}; + } +}; + + +template +aotriton::TensorView mk_aotensor(const at::Tensor& q, std::string_view tensor_name) +{ + const auto strides = q.strides(); + int real_rank = strides.size(); + if (real_rank != Rank) { // Lazy convertion of tensor_name + TORCH_CHECK(false, + std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + + " but is " + std::to_string(real_rank)); + } + return aotriton::TensorView(reinterpret_cast(q.data_ptr()), + IntArrayRefCaster::cast(q.sizes()), + IntArrayRefCaster::cast(strides), + cast_dtype(q.dtype())); +} + +inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q) +{ + return aotriton::TensorView<0>(reinterpret_cast(q.data_ptr()), + cast_dtype(q.dtype())); +} + +inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kUInt64); // AOTriton accepts unsigned int64 +} + +inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kInt32); +} + +} // namespace aotriton_adapter + +} // namespace sdp + +namespace at::native { + +inline int64_t ceil_div(int64_t 
numerator, int64_t denominator) { + return (numerator + (denominator - 1)) / denominator; +} + +} + +#endif // USE_ROCM diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h new file mode 100644 index 0000000000000000000000000000000000000000..408eb9180ac55d6c4a338ca469c00e7c391084e1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h @@ -0,0 +1,67 @@ +#pragma once +#include + +#include + +#if defined(USE_CK_FLASH_ATTENTION) +namespace pytorch_flash { + +std::tuple< + at::Tensor, // output + at::Tensor, // q + at::Tensor, // k + at::Tensor, // v + at::Tensor, // lse + at::Tensor, // seed + at::Tensor, // offset + at::Tensor> // dropout randval +mem_eff_forward_ck( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + float p_dropout, + bool return_dropout_randval, + std::optional is_causal, + std::optional scale, + const std::optional& attn_bias_, + std::optional& out_, + const std::optional& cu_seqlens_q, + const std::optional& cu_seqlens_k, + const std::optional& seqstart_q, + const std::optional& seqstart_k, + std::optional gen_, + std::optional& seqused_k_ +); + +std::tuple< + at::Tensor, // dQ + at::Tensor, // dK + at::Tensor, // dV + at::Tensor> // dBias +mem_eff_backward_ck( + const at::Tensor &dout, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const at::Tensor &softmax_lse, + const at::Tensor &dq_, + const at::Tensor &dk_, + const at::Tensor &dv_, + std::optional &attn_bias, + bool bias_requires_grad, + std::optional &grad_bias, + std::optional &cu_seqlens_q, + std::optional &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float p_dropout, + float scale, + bool is_causal, + bool deterministic, + bool zero_tensors, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +} // namespace pytorch_flash +#endif // USE_CK_FLASH_ATTENTION diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 0000000000000000000000000000000000000000..4fecda71587ba089e1c9a4cdea2fafe35abd09db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,599 @@ +#pragma once +#include + +#include +#include +#include + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace pytorch_flash { + +// AOTriton Implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> 
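// Worked example (annotation only): at::native::ceil_div above is the
// standard round-up integer division (numerator + denominator - 1) /
// denominator, typically used to turn an element count and a tile size into
// a grid dimension. Self-contained checks of the arithmetic:
static_assert((10 + (4 - 1)) / 4 == 3,
              "covering 10 elements with tiles of 4 takes 3 tiles");
static_assert((12 + (4 - 1)) / 4 == 3,
              "an exact multiple is not rounded up further");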
+mha_fwd_aot( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_aot( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& block_table_, + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple mha_bwd_aot( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +std::tuple mha_varlen_bwd_aot( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& 
alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +#if defined(USE_CK_FLASH_ATTENTION) +// CK implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_ck( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); // batch_size x nheads x seqlen_q x seqlen_k + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_ck( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. 
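// Worked example (annotation only): the varlen entry points take packed
// inputs of shape [total_q, num_heads, head_size] plus cumulative sequence
// length tensors of length b+1. For a batch of three sequences with lengths
// {3, 5, 2}:
//   cu_seqlens_q = {0, 3, 8, 10}   // prefix sums, so total_q == 10
// rows [cu_seqlens_q[i], cu_seqlens_q[i+1]) of the packed tensor belong to
// batch element i, and max_seqlen_q (here 5) is used to choose the kernel.
constexpr int kExampleCuSeqlensQ[] = {0, 3, 8, 10};
static_assert(kExampleCuSeqlensQ[3] == 3 + 5 + 2,
              "the final entry equals total_q, the sum of per-sequence lengths");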
+ int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); + +std::tuple mha_bwd_ck( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k + bool bias_requires_grad, + std::optional& grad_bias, + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple mha_varlen_bwd_ck( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& attn_bias_, // num_heads or b x num_heads + bool bias_requires_grad, + std::optional& grad_bias, + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#endif + +TORCH_API +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + const int 
non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} + +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& + block_table_, // Not used on ROCm. Keeping for parity with CUDA + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional dummy_attn_bias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + return mha_varlen_fwd_ck( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_varlen_fwd_aot( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} + +inline std::tuple mha_bwd( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + std::optional window_size_left, + std::optional 
window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { +#if defined(USE_CK_FLASH_ATTENTION) + std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = mha_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); +#else + TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); +#endif + } + return mha_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +inline std::tuple mha_varlen_bwd( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = mha_varlen_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + non_null_window_left, + non_null_window_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return 
std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } +#endif + return mha_varlen_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +} // namespace pytorch_flash diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/sdp_utils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/sdp_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..884531d0016496f00574af2b034cdec54915059f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/sdp_utils.h @@ -0,0 +1,88 @@ +#pragma once +#include +#include + +namespace at::native { + +void alloc_with_matching_layout( + const Tensor& q, + Tensor& output, + const std::vector& shape) { + TORCH_INTERNAL_ASSERT( + shape.size() == q.sizes().size(), + "SDPA alloc_with_matching_layout got requested shape ndim != q ndim"); + + if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) { + output = at::empty_like(q); + return; + } + + // get the "fill order," which is just an argsort on the strides + std::vector fill_order(shape.size()); + std::iota(fill_order.begin(), fill_order.end(), 0); + const auto q_strides = q.strides(); + std::stable_sort( + fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) { + return q_strides[idx1] < q_strides[idx2]; + }); + std::vector ordered_strides(shape.size()); + int64_t current_stride = 1; + for (const int dim_idx : fill_order) { + ordered_strides[dim_idx] = current_stride; + current_stride *= shape[dim_idx]; + } + output = at::empty(at::IntArrayRef(shape), q.options()) + .as_strided( + at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0); +} + +void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) { + const int dims = output.sizes().size(); + std::vector outer_to_inner(dims); + std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0); + const auto o_strides = output.strides(); + std::stable_sort( + outer_to_inner.begin(), + outer_to_inner.end(), + [&o_strides](int idx1, int idx2) { + return o_strides[idx1] > o_strides[idx2]; + }); + std::vector inverse(dims); + for (int d = 0; d < dims; d++) { + inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) - + outer_to_inner.begin(); + } + grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner)) + .contiguous() + .permute(at::IntArrayRef(inverse)); +} + +bool same_strides(const Tensor& t1, const Tensor& t2) { + std::vector t1_strides_no_ones; + std::vector t2_strides_no_ones; + const auto t1strides = t1.strides(); + const auto t2strides = t2.strides(); + const int dim = t1strides.size(); + if (dim != (int)t2strides.size()) { + return false; + } + const auto t1sizes = t1.sizes(); + const auto t2sizes = t2.sizes(); + + // we are going through strides backward here, but if both are backward it's + // comparable + for (int i = 0; i < dim; i++) { + if (t1sizes[i] > 1) { + t1_strides_no_ones.push_back(t1strides[i]); + } + if (t2sizes[i] > 1) { + t2_strides_no_ones.push_back(t2strides[i]); + } + } + return std::equal( + t1_strides_no_ones.begin(), + t1_strides_no_ones.end(), + t2_strides_no_ones.begin(), + t2_strides_no_ones.end()); +} +} // namespace at::native diff --git 
a/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h new file mode 100644 index 0000000000000000000000000000000000000000..a78926f179d57d9dff0ab8b042ac93e35873f95a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h @@ -0,0 +1,560 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace sdp { + +constexpr int32_t num_backends = at::num_sdp_backends; +using SDPBackend = at::SDPBackend; + +// Note that if this changed make sure to update +// the templated enum in mem_eff/kernel_forward.h and mem_eff/kernel_backward.h +enum class CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, +}; + +struct sdp_params { + at::Tensor query; + at::Tensor key; + at::Tensor value; + std::optional attn_mask; + double dropout; + bool is_causal; + bool enable_gqa; +}; + +SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params); + +inline c10::SymFloat calculate_scale( + const at::Tensor& query, + std::optional scale) { + const auto softmax_scale = scale.has_value() + ? scale.value() + : (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt())); + return c10::SymFloat(softmax_scale); +} + +inline bool input_requires_grad(sdp_params const& params) { + const bool any_inputs_require_grad = params.query.requires_grad() || + params.key.requires_grad() || params.value.requires_grad(); + const bool gradmode_enabled = at::GradMode::is_enabled(); + return any_inputs_require_grad && gradmode_enabled; +} + +inline bool has_for_nested_inputs(sdp_params const& params) { + return + (params.query.is_nested() && params.query.layout() == c10::kStrided) || + (params.key.is_nested() && params.key.layout() == c10::kStrided) || + (params.value.is_nested() && params.value.layout() == c10::kStrided); +} + +inline bool has_for_dense_inputs(sdp_params const& params) { + return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested(); +} + +inline bool has_only_dense_inputs(sdp_params const& params) { + return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested(); +} + +template +inline bool check_tensor_dtype( + sdp_params const& params, + dtype_vector allowed_dtypes, + bool debug) { + auto query_dtype = params.query.dtype(); + if (!(query_dtype == params.key.dtype() && + query_dtype == params.value.dtype() && + (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) != + allowed_dtypes.end()))) { + if (debug) { + TORCH_WARN( + "Expected query, key and value to all be of dtype: {", + c10::Join(", ", allowed_dtypes), + "}. 
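// Illustrative sketch (not part of the vendored header): calculate_scale
// above falls back to the standard scaled-dot-product-attention scale
// 1/sqrt(E) when no explicit scale is supplied, where E = query.size(-1) is
// the head dimension; for E = 64 that is 1/8 = 0.125. A plain-float
// equivalent, assuming <cmath> is available:
inline double default_sdpa_scale(int64_t head_dim) {
  return 1.0 / std::sqrt(static_cast<double>(head_dim));  // 64 -> 0.125
}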
Got ", + "Query dtype: ", + params.query.dtype(), + ", Key dtype: ", + params.key.dtype(), + ", and Value dtype: ", + params.value.dtype(), + " instead."); + } + return false; + } + return true; +} + + +inline bool try_broadcast_param_size( + const c10::SymInt q_size, + const c10::SymInt k_size, + const c10::SymInt v_size, + std::string_view param_name, + bool debug) { + auto max_size = std::max({q_size, k_size, v_size}); + if ((q_size != max_size && q_size != 1) || + (k_size != max_size && k_size != 1) || + (v_size != max_size && v_size != 1)) { + if (debug) { + TORCH_WARN( + "Both fused kernels require query, key and value to have broadcastable ", + param_name, + "got Query ", + param_name, + q_size, + ", Key ", + param_name, + k_size, + ", Value ", + param_name, + v_size, + " instead."); + } + return false; + } + return true; +} + +inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + at::Tensor const& param, + std::string_view param_name, + bool debug) { + const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param); + const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes(); + auto num_head_dims = nt_tensor_impl->opt_size(1); + if (!num_head_dims.has_value()) { + // num_head_dims is ragged + if (debug) { + TORCH_WARN( + "Fused kernels do not support ragged num_head_dims, ", + param_name, + "has a ragged num_heads."); + } + return false; + } + + auto* sizes_ptr = sizes.data_ptr(); + const int64_t n_tensors = param.size(0); + const int64_t size_tensor_stride = sizes.stride(0); + + // This is being called inside sdp with shape [batch, heads, {seq_len}, dim] + for (const auto i : c10::irange(n_tensors)) { + if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) { + if (debug) { + TORCH_WARN( + "Fused kernels do not support seq_len == 0, ", + param_name, + "has a seq len of 0."); + } + return false; + } + } + return true; +} + +inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) { + // When this function is called we are assured that the nt is dim==4 + bool q_is_safe = params.query.is_nested() + ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + params.query, "query ", debug) + : true; + // short circuit if any is unsafe + if (!q_is_safe) { + return false; + } + + bool k_is_safe = params.key.is_nested() + ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + params.key, "key ", debug) + : true; + if (!k_is_safe) { + return false; + } + + bool v_is_safe = params.value.is_nested() + ? 
check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + params.value, "value ", debug) + : true; + if (!v_is_safe) { + return false; + } + + // We now know none of the inputs have ragged num_heads, so we can safely + // access .size(1) + auto q_num_heads = params.query.size(1); + auto k_num_heads = params.key.size(1); + auto v_num_heads = params.value.size(1); + bool same_num_heads = + q_num_heads == k_num_heads && q_num_heads == v_num_heads; + + if (!same_num_heads) { + if (input_requires_grad(params)){ + if (debug) { + TORCH_WARN( + "Both fused kernels do not support training with broadcasted NT inputs."); + } + return false; + } + return try_broadcast_param_size( + q_num_heads, k_num_heads, v_num_heads, "num heads ", debug); + } + + return true; +} + +inline bool check_nested_tensor(sdp_params const& params, bool debug) { + // Return false if have nested tensor + if (!has_only_dense_inputs(params)) { + if (debug) { + TORCH_WARN( + "Both fused kernels of cpp version currently do not support Nested Tensor inputs."); + } + return false; + } + return true; +} + +inline bool check_for_dropout(sdp_params const& params, bool debug) { + if (params.dropout > 0.0) { + if (debug) { + TORCH_WARN("Both fused kernels do not support non-zero dropout."); + } + return false; + } + return true; +} + +inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) { + if (input_requires_grad(params)) { + if (debug) { + TORCH_WARN( + "Memory efficient attention currently doesn't support training with NT inputs."); + } + return false; + } + return true; +} + +inline bool check_for_attn_mask(sdp_params const& params, bool debug) { + if (params.attn_mask.has_value()) { + if (debug) { + TORCH_WARN("Flash Attention does not support non-null attn_mask."); + } + return false; + } + return true; +} + +inline bool check_attn_mask_shape(sdp_params const& params, bool debug) { + auto attn_mask = params.attn_mask; + if (!attn_mask.has_value()) { + return true; + } + if (attn_mask.value().requires_grad()) { + return false; + } + auto batchSize = params.query.sym_size(0); + auto qSize = params.query.sym_size(2); + auto kvSize = params.key.sym_size(2); + auto num_head = params.query.sym_size(1); + if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) { + return false; + } + if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) { + return false; + } + if (attn_mask.value().dim() == 2) { + return true; + } else if (attn_mask.value().dim() == 4) { + if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize) + && (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) { + return true; + } + } + if (debug) { + TORCH_WARN("Please use the following attn mask shapes: ", + "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ", + "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})"); + } + return false; +} + +inline bool check_tensor_shapes(sdp_params const& params, bool debug) { + auto query_dim = params.query.dim(); + if (!(query_dim == params.key.dim() && query_dim == params.value.dim() && + (query_dim == 4))) { + if (debug) { + TORCH_WARN( + "All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ", + query_dim, + ", Key dim: ", + params.key.dim(), + ", Value dim: ", + params.value.dim(), + " instead."); + } + return false; + } + return true; +} + +inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) { + const auto 
nt_tensor_impl = at::native::get_nested_tensor_impl(param); + auto seq_len = nt_tensor_impl->opt_size(2); + if (!seq_len.has_value()) { + if (debug) { + TORCH_WARN( + "For both fused kernels, if one of key/value batch_size requires " + "broadcasting and the other does not, then the other must have a ", + "consistent seq_len dim.") + } + return false; + } + return true; +} + +template +inline bool check_grouped_query_attention(sdp_params const& params, bool debug) { + const auto q_num_heads = params.query.sym_size(-3); + const auto k_num_heads = params.key.sym_size(-3); + const auto v_num_heads = params.value.sym_size(-3); + const bool same_kv_heads = k_num_heads == v_num_heads; + + if (requires_same_num_heads && !(same_kv_heads)){ + if (debug) { + TORCH_WARN( + "Both fused kernels require key and value to have the same num_heads and batch_size but got: ", + "Key sizes: ", + params.key.sizes(), + ", Value sizes: ", + params.value.sizes(), + ", Query sizes: ", + params.query.sizes(), + " instead."); + } + return false; + } + // Check if grouped query attention is supported and validate the number of + // heads + if (q_num_heads % k_num_heads != 0 || (!requires_same_num_heads && (q_num_heads % v_num_heads != 0))) { + if (debug) { + TORCH_WARN( + "The number of heads in key/value must divide number of heads in query.", + "Got input Key sizes(): ", + params.key.sym_size(-3), + ", Value sizes(): ", + params.value.sym_size(-3), + ", Query sizes(): ", + params.query.sym_size(-3), + " instead."); + } + return false; + } + return true; +} + +template +inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) { + // This is expected to be called after check_tensor_shapes ensuring that the + // size() calls won't error since the inputs are all 4 dimensional + + auto q_batch_size = params.query.sym_size(0); + auto k_batch_size = params.key.sym_size(0); + auto v_batch_size = params.value.sym_size(0); + + bool same_batch_size = + q_batch_size == k_batch_size && q_batch_size == v_batch_size; + + auto q_num_heads = params.query.sym_size(-3); + auto k_num_heads = params.key.sym_size(-3); + auto v_num_heads = params.value.sym_size(-3); + + bool same_num_heads = + q_num_heads == k_num_heads && q_num_heads == v_num_heads; + + if (!same_batch_size){ + if(debug) { + TORCH_WARN( + "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ", + "Query.sizes(): ", + params.query.sizes(), + ", Key.sizes(): ", + params.key.sizes(), + ", Value.sizes(): ", + params.value.sizes(), + " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel."); + } + return false; + } + + if(params.enable_gqa && supports_gqa){ + return check_grouped_query_attention(params, debug); + } + + // same num heads condition for non-gqa case + if (!same_num_heads){ + if (debug) { + TORCH_WARN( + "For dense input, both fused kernels require query, key and value to have the same num_heads. ", + "Query.sizes(): ", + params.query.sizes(), + ", Key sizes(): ", + params.key.sizes(), + ", Value sizes(): ", + params.value.sizes(), + " instead. 
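// Worked example (annotation only): check_grouped_query_attention above
// accepts grouped-query attention when the query head count is an integer
// multiple of the key/value head count, e.g. 32 query heads sharing 8 KV
// heads (4 query heads per KV head); 30 query heads over 8 KV heads would
// fail the modulo check.
static_assert(32 % 8 == 0, "32 query heads map evenly onto 8 KV heads");
static_assert(30 % 8 != 0, "30 query heads do not map evenly onto 8 KV heads");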
To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel."); + } + return false; + } + // If all checks pass, return true + return true; +} + +inline bool check_batch_size_nested(sdp_params const& params, bool debug) { + // This is expected to be called after check_tensor_shapes ensuring that the + // size() calls won't error since the inputs are all 4 dimensional + auto q_batch_size = params.query.sym_size(0); + auto k_batch_size = params.key.sym_size(0); + auto v_batch_size = params.value.sym_size(0); + + bool same_batch_size = + q_batch_size == k_batch_size && q_batch_size == v_batch_size; + + // num_heads logic for nested input is checked in + // check_for_seq_len_0_nested_tensor as there is handling there to make sure + // num_heads is not ragged + bool broadcastable_batch_size = true; + if (!same_batch_size) { + if (input_requires_grad(params)){ + if (debug) { + TORCH_WARN( + "Both fused kernels do not support training with broadcasted NT inputs."); + } + return false; + } + // try to broadcast batchsize + broadcastable_batch_size = try_broadcast_param_size( + q_batch_size, k_batch_size, v_batch_size, "batch size ", debug); + + // if only one of k or v require broadcasting of batch size, the other + // must have a consistent seq_len dim + if (broadcastable_batch_size) { + if (k_batch_size == 1 && v_batch_size != 1 && + !check_safe_kv_broadcast(params.value, debug)) { + return false; + } + if (v_batch_size == 1 && k_batch_size != 1 && + !check_safe_kv_broadcast(params.key, debug)) { + return false; + } + } + } + return broadcastable_batch_size; +} + +inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) { + // In some cases people will pass in 0 sized tensors, this will + // cause the fused path to error with unaligned mask + bool zero_seq_len_q = params.query.sym_size(-2) == 0; + bool zero_seq_len_k = params.key.sym_size(-2) == 0; + if (zero_seq_len_q || zero_seq_len_k) { + if (debug) { + TORCH_WARN( + "All fused kernels do not support zero seq_len_q or seq_len_kv."); + } + return false; + } + return true; +} + +template +inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) { + // The stride checking for NestedTensors is done within the kernel + // And .contiguous will be called if needed + + // This function checks that the last dimension of the inputs to + // fused_attention have stride 1 + bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 && + params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1; + + // https://github.com/pytorch/pytorch/issues/116333 + // If the head_dim is size 1 the stride won't matter, but we + // check this condition before padding the head_dim to 1 + if (ignore_singleton_dim){ + qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1; + } + bool is_cpu = params.query.device().type() == c10::DeviceType::CPU; + bool mask_stride_equal_1 = params.attn_mask.has_value() + ? params.attn_mask.value().sym_stride(-1) == 1 + : true; + bool mask_stride_valid = is_cpu ? true : mask_stride_equal_1; + if (!(qkv_strides_equal_1 && mask_stride_valid)) { + if (debug) { + std::ostringstream message; + message + << "All fused kernels require the last dimension of the input to have stride 1. 
"; + message << "Got Query.stride(-1): " << params.query.sym_stride(-1) + << ", Key.stride(-1): " << params.key.sym_stride(-1) + << ", Value.stride(-1): " << params.value.sym_stride(-1); + + if (params.attn_mask.has_value()) { + message + << ", Attn_mask.stride(-1): " + << params.attn_mask.value().sym_stride(-1) + << " (GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not)."; + } + TORCH_WARN(message.str()); + } + + return false; + } + return true; +} + +inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) { + // We check the global context to see if user has explicitly turned of flash + // sdp kernels + if (!at::globalContext().userEnabledFlashSDP()) { + if (debug) { + TORCH_WARN("Flash attention has been runtime disabled."); + } + return false; + } + return true; +} + +inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) { + // We check the global context to see if user has explicitly turned of + // mem_efficient sdp kernels + if (!at::globalContext().userEnabledMemEfficientSDP()) { + if (debug) { + TORCH_WARN("Memory Efficient attention has been runtime disabled."); + } + return false; + } + return true; +} + + +} // namespace sdp diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/utils/Factory.h b/phivenv/Lib/site-packages/torch/include/ATen/native/utils/Factory.h new file mode 100644 index 0000000000000000000000000000000000000000..1ca543b6727ef645e762ff391e4fabe3c10bc85b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/utils/Factory.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at::native::mobile { + +Tensor allocate_padded_contiguous_if_needed( + const Tensor& input, + c10::MemoryFormat memory_format); + +// TODO: Remove this function when at::native::empty() is modified to accept a +// custom memory allocator. 
+ +at::Tensor empty_with_tail_padding( + IntArrayRef size, + const caffe2::TypeMeta dtype, + c10::MemoryFormat memory_format, + std::optional maybe_names); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/utils/ParamUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/native/utils/ParamUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..1a8e6d10a2371cbde748b4c02c28a1c995eb7c87 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/utils/ParamUtils.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { + +template +inline std::vector _expand_param_if_needed( + ArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + if (list_param.size() == 1) { + return std::vector(expected_dim, list_param[0]); + } else if ((int64_t)list_param.size() != expected_dim) { + std::ostringstream ss; + ss << "expected " << param_name << " to be a single integer value or a " + << "list of " << expected_dim << " values to match the convolution " + << "dimensions, but got " << param_name << "=" << list_param; + TORCH_CHECK(false, ss.str()); + } else { + return list_param.vec(); + } +} + +inline std::vector expand_param_if_needed( + IntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +inline std::vector expand_param_if_needed( + SymIntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/utils/ParamsHash.h b/phivenv/Lib/site-packages/torch/include/ATen/native/utils/ParamsHash.h new file mode 100644 index 0000000000000000000000000000000000000000..24c836f3308d145f8a3e56f0d240404f53df5f0d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/utils/ParamsHash.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +// Hashing machinery for Params +// Fowler–Noll–Vo hash function +// see +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct ParamsHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + size_t operator()(const Params& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(Params))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +template +struct ParamsEqual { + // Params must be a POD because we read out its memory + // contents as char* when comparing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + bool operator()(const Params& a, const Params& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(Params)) == 0; + } +}; + +// Provide explicit byte-for-byte constructors to avoid uwittingly leaving +// padding bytes unitialized (e.g., when passing Params by value) +template +struct ParamsWrapper { + T pod; + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + ParamsWrapper() { + memset(&(this->pod), 0, sizeof(this->pod)); + } + + ParamsWrapper(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + } + + 
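+  // The "move" operations below deliberately copy the raw bytes instead of
+  // moving members: keeping every byte (including padding) identical is what
+  // keeps the memcmp-based operator== and the FNV hash over this pod valid.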
ParamsWrapper(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + } + + ParamsWrapper& operator=(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + } + + ParamsWrapper& operator=(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + } + + inline friend bool operator==( + const ParamsWrapper& lhs, + const ParamsWrapper& rhs) noexcept { + auto ptr1 = reinterpret_cast(&(lhs.pod)); + auto ptr2 = reinterpret_cast(&(rhs.pod)); + return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0; + } +}; + +// Wrapped version: this allows the outer struct to have custom copy and move +// constructors for additional safety +template +struct ParamsWrapperHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + size_t operator()(const ParamsWrapper& params_wrapper) const { + auto ptr = reinterpret_cast(&(params_wrapper.pod)); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(params_wrapper.pod))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/verbose_wrapper.h b/phivenv/Lib/site-packages/torch/include/ATen/native/verbose_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..fc16ad2c373177cb92d297b4b78da0efa9800225 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/verbose_wrapper.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace torch::verbose { +TORCH_API int _mkl_set_verbose(int enable); +TORCH_API int _mkldnn_set_verbose(int level); +} // namespace torch::verbose diff --git a/phivenv/Lib/site-packages/torch/include/ATen/native/vol2col.h b/phivenv/Lib/site-packages/torch/include/ATen/native/vol2col.h new file mode 100644 index 0000000000000000000000000000000000000000..4e756e356759ad663a0b387082c47d40cb38e139 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/native/vol2col.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +namespace at::native { + +template +void vol2col( + const T* data_vol, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t depth_col, + const int64_t height_col, + const int64_t width_col, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_col) { + int64_t c, t, h, w; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_col[((c * depth_col + t) * height_col + h) * 
width_col + w] = + data_vol + [((c_vol * depth + t_pad) * height + h_pad) * width + + w_pad]; + else + data_col[((c * depth_col + t) * height_col + h) * width_col + w] = + 0; + } + } + } + } +} + +template +void col2vol( + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t out_depth, + const int64_t out_height, + const int64_t out_width, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_vol) { + memset(data_vol, 0, sizeof(T) * depth * height * width * channels); + int64_t depth_col = out_depth; + int64_t height_col = out_height; + int64_t width_col = out_width; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (int64_t c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (int64_t t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (int64_t h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (int64_t w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_vol + [((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] += + data_col + [((c * depth_col + t) * height_col + h) * width_col + w]; + } + } + } + } +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/abs.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs.h new file mode 100644 index 0000000000000000000000000000000000000000..44fec6993682e5bc6d8a10f69dac9821f69dea70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::abs(Tensor self) -> Tensor +inline at::Tensor abs(const at::Tensor & self) { + return at::_ops::abs::call(self); +} + +// aten::abs_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & abs_(at::Tensor & self) { + return at::_ops::abs_::call(self); +} + +// aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::abs_out::call(self, out); +} +// aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
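+// Usage sketch (illustrative, not generated): both overloads forward to the
+// same at::_ops::abs_out::call and differ only in where `out` sits, e.g.
+//   at::Tensor x = at::randn({3});
+//   at::Tensor out = at::empty_like(x);
+//   at::abs_out(out, x);   // equivalently: at::abs_outf(x, out);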
+inline at::Tensor & abs_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::abs_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6464bff68321e5e032d3caeab3c5607692887418 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor abs(const at::Tensor & self); +TORCH_API at::Tensor & abs_(at::Tensor & self); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3318cfc1b25c809c02e00de2a958fa2889d0038d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_cpu_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & abs_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0d19eeae3b4cbb14e13e3ae89b7d264a564fdf15 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_cuda_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & abs_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_native.h new file mode 100644 index 0000000000000000000000000000000000000000..93e181678e7cd904de4d659edad294406a288011 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_native.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor abs(const at::Tensor & self); +TORCH_API at::Tensor & abs_(at::Tensor & self); +TORCH_API at::Tensor & abs_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor NestedTensor_abs(const at::Tensor & self); +TORCH_API at::Tensor & NestedTensor_abs_(at::Tensor & self); +TORCH_API at::Tensor abs_sparse(const at::Tensor & self); +TORCH_API at::Tensor & abs_sparse_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & abs_sparse_(at::Tensor & self); +TORCH_API at::Tensor abs_sparse_csr(const at::Tensor & self); +TORCH_API at::Tensor & abs_sparse_csr_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & abs_sparse_csr_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..af16bbb35ef90ec2b82c6a378de997395a5b44a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/abs_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API abs { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::abs"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "abs(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API abs_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::abs_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "abs_(Tensor(a!) 
self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API abs_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::abs"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute.h new file mode 100644 index 0000000000000000000000000000000000000000..072e7cf84491cdf6e380ac56b4f2e0d9a59a26c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::absolute(Tensor self) -> Tensor +inline at::Tensor absolute(const at::Tensor & self) { + return at::_ops::absolute::call(self); +} + +// aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & absolute_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::absolute_out::call(self, out); +} +// aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & absolute_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::absolute_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1638f19b5f208cf4c3cf1479f2095b8e45497c6c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor absolute(const at::Tensor & self); +TORCH_API at::Tensor & absolute_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & absolute_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & absolute_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_native.h new file mode 100644 index 0000000000000000000000000000000000000000..e58f39429ae83bf39988724a2a13a8158373e9e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor absolute(const at::Tensor & self); +TORCH_API at::Tensor & absolute_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & absolute_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..e1d15a33da495a80d08cc395a44da36340cc7864 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/absolute_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API absolute { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::absolute"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "absolute(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API absolute_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::absolute_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "absolute_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API absolute_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::absolute"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "absolute.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos.h new file mode 100644 index 0000000000000000000000000000000000000000..a1d393e85ab0b74db9ffb1c96102aee571fe034c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::acos(Tensor self) -> Tensor +inline at::Tensor acos(const at::Tensor & self) { + return at::_ops::acos::call(self); +} + +// aten::acos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & acos_(at::Tensor & self) { + return at::_ops::acos_::call(self); +} + +// aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::acos_out::call(self, out); +} +// aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::acos_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..17e29662d852815777e082ad49320532f5d400b7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..2396f7b6f9ddeeee84eb60f25d40d583e154ae14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..31049f254f7e6138f568bcd22b976218f256c57b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..941a1d09d4b42f6f09cd300fe73f0b54a1cc5cc0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_acos : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..67c00a184c2419173a7a3078a7409ef37588e2b6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
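+// Sketch (not generated): the at::meta:: entry points declared below only run
+// the meta function for acos, so calling the op on a Meta-device tensor, e.g.
+//   at::Tensor m = at::acos(at::empty({4}, at::kMeta));
+// yields a tensor with the right sizes and dtype but no data, which is roughly
+// how structured kernels and shape inference use these declarations.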
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_native.h new file mode 100644 index 0000000000000000000000000000000000000000..b5be7cf21dd6e9c841b20e7787f5e6e93acd7356 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_acos_out : public at::meta::structured_acos { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..3a53fae79bf2d7610f7b5b2b4b3acd41d1d5a7ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acos_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API acos { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::acos"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "acos(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API acos_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::acos_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "acos_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API acos_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::acos"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "acos.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh.h new file mode 100644 index 0000000000000000000000000000000000000000..b70bd90cc5a2bc4fa27e06777357deffeddd6af2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::acosh(Tensor self) -> Tensor +inline at::Tensor acosh(const at::Tensor & self) { + return at::_ops::acosh::call(self); +} + +// aten::acosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & acosh_(at::Tensor & self) { + return at::_ops::acosh_::call(self); +} + +// aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::acosh_out::call(self, out); +} +// aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::acosh_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..279bd1957fe1e17321b6d062b343d12459bfb487 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..026f535fdf09102ea68284aaf189f3de877f1ed9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..a31a1c967c1185ee9b87bd763c281a1ee491d141 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..76d2d506176901bfd907c5cdff1c06b528eb44ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_acosh : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d6f867a40b3cc2dc93a7b780cf0142f929228d36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_native.h new file mode 100644 index 0000000000000000000000000000000000000000..a5acdcb203af6f9def50f39f18c69035cf9b5c90 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_acosh_out : public at::meta::structured_acosh { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..d4e61a9d34617b77952c49acbe9aa292fc0423ab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/acosh_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API acosh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::acosh"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "acosh(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API acosh_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::acosh_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "acosh_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API acosh_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::acosh"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "acosh.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d.h new file mode 100644 index 0000000000000000000000000000000000000000..db1640008906523dabd2c876e77a56adc82eb905 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool1d::call(self, output_size); +} + +// aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool1d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool1d_out::call(self, output_size, out); +} +// aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool1d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool1d_out::call(self, output_size, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4f392f57a346ebb742a9a0b26e58c7cc3bb310f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & adaptive_avg_pool1d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool1d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..45bb93f27cacd0ded73f9227e7be915c3f3de42d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..c1672dd16a7f93027165a8bb1609c31d9859832d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool1d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a3295c1b98a6464f0a0465936321ad2cd101b7f5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
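+// Sketch (not generated): every operator gets one such _ops struct; the inline
+// wrappers in adaptive_avg_pool1d.h simply forward to
+//   at::_ops::adaptive_avg_pool1d::call(self, output_size);
+// and `redispatch` is the variant that takes an already-computed
+// DispatchKeySet.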
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool1d { + using schema = at::Tensor (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_avg_pool1d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::IntArrayRef output_size); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +struct TORCH_API adaptive_avg_pool1d_out { + using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_avg_pool1d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d.h new file mode 100644 index 0000000000000000000000000000000000000000..77701febb973bcdbda417ece164fd806ea95a381 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d.h @@ -0,0 +1,92 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & adaptive_avg_pool2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, c10::fromIntArrayRefSlow(output_size)); +} +namespace symint { + template >> + at::Tensor adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, c10::fromIntArrayRefSlow(output_size)); + } +} + +// aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, output_size); +} +namespace symint { + template >> + at::Tensor adaptive_avg_pool2d(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, output_size); + } +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..b14e16398365dbed76601b8a1fecc31053869166 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
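+// Sketch (not generated): the _symint variants keep output_size as
+// c10::SymIntArrayRef so sizes may stay symbolic, while the IntArrayRef
+// overloads in adaptive_avg_pool2d.h convert through c10::fromIntArrayRefSlow
+// before reaching the same at::_ops::adaptive_avg_pool2d::call.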
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor adaptive_avg_pool2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..168960f65674dc06fb52a3331be4ff949a7ac84e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6363b3a2ce1d7c7e914ad1fd03301b649157d405 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..bf23c06082a22925e8cc440bcaf9ef356e74629e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor adaptive_avg_pool2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_out_cpu(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool2d_out_cuda(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & mkldnn_adaptive_avg_pool2d_out_stub(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..5b87d78e27c239877669735da92ec6b831b7aa10 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool2d_out { + using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_avg_pool2d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); +}; + +struct TORCH_API adaptive_avg_pool2d { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_avg_pool2d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor"; + static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef output_size); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d.h new file mode 100644 index 0000000000000000000000000000000000000000..f5cb5aab988da9b66213d3babe6c7870042c9a23 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d.h @@ -0,0 +1,92 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) 
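The public wrappers are thin: each one forwards to the generated at::_ops record, whose call() takes the canonical SymInt schema. Calling that record directly is equivalent, as in this sketch (assumes <ATen/ATen.h>; c10::fromIntArrayRefSlow is the same conversion the IntArrayRef wrappers above use):

```cpp
#include <ATen/ATen.h>

at::Tensor pool_via_ops(const at::Tensor& self) {
  // Equivalent to at::adaptive_avg_pool2d(self, {7, 7}).
  return at::_ops::adaptive_avg_pool2d::call(self, c10::fromIntArrayRefSlow({7, 7}));
}
```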
+inline at::Tensor & adaptive_avg_pool3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); +} +namespace symint { + template >> + at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, c10::fromIntArrayRefSlow(output_size)); +} +namespace symint { + template >> + at::Tensor adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, c10::fromIntArrayRefSlow(output_size)); + } +} + +// aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, output_size); +} +namespace symint { + template >> + at::Tensor adaptive_avg_pool3d(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, output_size); + } +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..ded91f4ae8d38071b84a440e7e26aab25b2fc8c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward.h @@ -0,0 +1,35 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::call(grad_output, self, grad_input); +} +// aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::call(grad_output, self, grad_input); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7d6271488475837cc58198cabb052077194eef9f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self); +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e8433c1e6df18a07e0b1e6a229ae5357530ab4dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self); +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..6513e07b0ff545782457d68f1db40e74368f960e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out_cpu(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out_cuda(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..536a95a5b81278c4ddc6a97328686dcad0b93c41 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
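A hedged sketch of the grad_input out-variant declared for adaptive_avg_pool3d_backward above (assumes <ATen/ATen.h>; grad_output carries the gradient w.r.t. the pooled output, and grad_input is written in place with self's shape):

```cpp
#include <ATen/ATen.h>

void avg_pool3d_backward_into(const at::Tensor& grad_output,
                              const at::Tensor& self,
                              at::Tensor& grad_input) {
  // grad_input receives d(loss)/d(self).
  at::adaptive_avg_pool3d_backward_out(grad_input, grad_output, self);
}
```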
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool3d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_avg_pool3d_backward"; + static constexpr const char* overload_name = "grad_input"; + static constexpr const char* schema_str = "adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..46872d80ad9b344940c662aac8ea906a77a329ef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor adaptive_avg_pool3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c8d2528bcee029c651ee315d8ae8e3f997ce25b6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..52f8f7a70f29d0d5da89ca7a2a5f03755c70e579 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..83845f0ba84a558492d0121f3c3d8f0419b757a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor adaptive_avg_pool3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_out_cpu(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_out_cuda(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_out_quantized_cpu(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..4072a6dff3a6bfdef1d34e344719443b2e020d99 --- /dev/null 
+++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool3d_out { + using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_avg_pool3d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); +}; + +struct TORCH_API adaptive_avg_pool3d { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_avg_pool3d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor"; + static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef output_size); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5d595e9bdb6acde75cfa1ced9dc8a805695be5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) +inline ::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool1d::call(self, output_size); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4546f00b62879ef7a6b3e241c3e7295bf36f61c9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types 
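adaptive_max_pool1d is the first of the max-pool entries and, unlike the average pools, returns a (values, indices) pair; the indices record the argmax positions that the backward pass later consumes. A short sketch, assuming <ATen/ATen.h> and C++17:

```cpp
#include <ATen/ATen.h>

void max_pool1d_example() {
  at::Tensor x = at::randn({2, 4, 16});                      // (N, C, L)
  auto [values, indices] = at::adaptive_max_pool1d(x, {8});  // both (2, 4, 8)
  // indices is an int64 tensor of argmax positions along L.
  TORCH_CHECK(values.sizes() == indices.sizes());
}
```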
needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API ::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..6e87f6513435d56dbbe2b4feb346949ae5abfe36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..844422c4f97e30a503ade33dd1a839d19843573e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool1d { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool1d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d.h new file mode 100644 index 0000000000000000000000000000000000000000..7e3c0741dec75e765ec6077ce59248799e064ae4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d_out::call(self, output_size, out, indices); +} +// aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_out::call(self, output_size, out, indices); +} + +// aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) +inline ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d::call(self, output_size); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..f5fdd711bf2eea95a71554ddd0d51842d910349e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::call(grad_output, self, indices, grad_input); +} +// aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::call(grad_output, self, indices, grad_input); +} + +// aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor +inline at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward::call(grad_output, self, indices); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..fde38f5cc95032d0abcefdc3587cf86990b78226 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
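The forward/backward pair declared above fits together as follows: the indices captured by adaptive_max_pool2d are exactly what adaptive_max_pool2d_backward needs to route gradients back to the argmax positions. A sketch (assumes <ATen/ATen.h>; grad_out is expected to match the pooled output's shape):

```cpp
#include <ATen/ATen.h>

at::Tensor max_pool2d_grad(const at::Tensor& x, const at::Tensor& grad_out) {
  auto [out, indices] = at::adaptive_max_pool2d(x, {7, 7});
  (void)out;  // the forward values are not needed to form the gradient
  return at::adaptive_max_pool2d_backward(grad_out, x, indices);
}
```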
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..798c421dbac601d569b02fb7bda45d1c5145c3c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e176d3228926b817dc8ea7b323f94a19a273e739 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..4f6ae4f748d769cd06a9d4e86f10e2b7ee812665 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool2d_backward : public at::impl::MetaBase { + + + void meta(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..89982f55f7f3227ae11c6f416446f4ef25221fa7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..957f8ae8cea37de139450fc874158a26a38fa912 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool2d_backward_out_cpu : public at::meta::structured_adaptive_max_pool2d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +struct TORCH_API structured_adaptive_max_pool2d_backward_out_cuda : public at::meta::structured_adaptive_max_pool2d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..7a80d728c3d2b0e0b542f8b544e9aed016110512 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool2d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool2d_backward"; + static constexpr const char* overload_name = "grad_input"; + static constexpr const char* schema_str = "adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) 
grad_input) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); +}; + +struct TORCH_API adaptive_max_pool2d_backward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool2d_backward"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor"; + static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f6a6e5560eff20cd763f28e34c6932ddd48ded3a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4d3ee55bcba794e2b8c47da8d94deabed98e1a00 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e30f25019eefa8d5e58c0ba2a75594470ae62c24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..afa4cd69561992749c49b9146aadc8ba1105d816 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool2d : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::IntArrayRef output_size); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d9668c4ece2710352a58d9951f04dd1e5409951b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
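Because adaptive_max_pool2d is a structured op (see the structured_adaptive_max_pool2d meta declaration above), it also gets an at::meta dispatch entry, which allows data-free shape inference: running the op on Meta-device tensors computes only sizes and strides, without allocating real storage. A sketch under that assumption:

```cpp
#include <ATen/ATen.h>

void infer_pooled_shape() {
  at::Tensor x = at::empty({2, 3, 64, 64},
                           at::TensorOptions().device(at::kMeta));
  auto [out, indices] = at::adaptive_max_pool2d(x, {5, 5});
  // out.sizes() == {2, 3, 5, 5}; only metadata was computed, no kernel ran.
  (void)out; (void)indices;
}
```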
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..159e481cb93692c145e508ed115e5a3afbc8cad8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool2d_out_cpu : public at::meta::structured_adaptive_max_pool2d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +struct TORCH_API structured_adaptive_max_pool2d_out_cuda : public at::meta::structured_adaptive_max_pool2d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..f62339d931e4f5853d51e9b75e023f3be21a9b75 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool2d_out { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool2d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))"; + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); +}; + +struct TORCH_API adaptive_max_pool2d { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool2d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d.h new file mode 100644 index 0000000000000000000000000000000000000000..48949e8e27e8260fdd5177569a299960b8d90375 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d_out::call(self, output_size, out, indices); +} +// aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_out::call(self, output_size, out, indices); +} + +// aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) +inline ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d::call(self, output_size); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..a4c7790d008a96f50e75396788d7113b4c36b48f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 
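For the 3-D max pool declared above, the out= spelling takes both destinations first: out receives the pooled values and indices the int64 argmax positions. A sketch (assumes <ATen/ATen.h>; empty destinations are resized by the kernel):

```cpp
#include <ATen/ATen.h>

void max_pool3d_into(const at::Tensor& x) {   // x: (N, C, D, H, W)
  at::Tensor out = at::empty({0}, x.options());
  at::Tensor indices = at::empty({0}, x.options().dtype(at::kLong));
  // Destination-first out= spelling; both tensors are resized and filled.
  at::adaptive_max_pool3d_out(out, indices, x, {4, 4, 4});
}
```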
+inline at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::call(grad_output, self, indices, grad_input); +} +// aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::call(grad_output, self, indices, grad_input); +} + +// aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor +inline at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward::call(grad_output, self, indices); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..39370d2febf72e354824f2de2da25812378534d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f82cc66807a723d28089c883a1ef9091344b5d8c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c8e4a5c65488f5d479b869a4e49280c4d9428079 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..7b42a759492e05d5406f9765ba8f37878823caf7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool3d_backward : public at::impl::MetaBase { + + + void meta(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ef8b885cbb1c5c98254c2eca5f8539f5e96f0310 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The 
only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..56c8f5576d1e9ca311006cae244d457e863ee21e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool3d_backward_out_cpu : public at::meta::structured_adaptive_max_pool3d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +struct TORCH_API structured_adaptive_max_pool3d_backward_out_cuda : public at::meta::structured_adaptive_max_pool3d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a0d0419cd1b5c6f70df64e7380b4ed34d6df5bad --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool3d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool3d_backward"; + static constexpr const char* overload_name = "grad_input"; + static constexpr const char* schema_str = "adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) 
grad_input) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); +}; + +struct TORCH_API adaptive_max_pool3d_backward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool3d_backward"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor"; + static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..684cfdeec4c31441721d1e75c50e4b7797773db6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..41dd65be3684d5febb0eb37d655c37c9f71d17ed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
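The Operator.h structs above carry the operator's registered name, overload name, and schema string, plus static call/redispatch entry points; the inline at:: wrappers shown earlier are thin forwarders to call. A sketch of that equivalence, assuming libtorch (the helper function name is made up for illustration):

#include <ATen/ATen.h>

// Hypothetical helper: both paths resolve to the same dispatcher entry.
at::Tensor backward_two_ways(const at::Tensor& grad_output,
                             const at::Tensor& self,
                             const at::Tensor& indices) {
  at::Tensor via_wrapper = at::adaptive_max_pool3d_backward(grad_output, self, indices);
  at::Tensor via_ops = at::_ops::adaptive_max_pool3d_backward::call(grad_output, self, indices);
  return via_wrapper - via_ops;  // elementwise zero: the two calls are equivalent
}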
+#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..55acbfee68654cf1c7578be8c3b1c9c992329fc2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..ff879e5455503d61e430375b276e2e1dad31dca9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool3d : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::IntArrayRef output_size); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..aaa79aaa2193efbd7f8d7dcde2865de883604108 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..48fdac18b505ea03b1f8e9548784d45030a7bb2b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool3d_out_cpu : public at::meta::structured_adaptive_max_pool3d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +struct TORCH_API structured_adaptive_max_pool3d_out_cuda : public at::meta::structured_adaptive_max_pool3d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..916fbbf0091eb57e91c3cd1616158dcb92ee5314 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool3d_out { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool3d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))"; + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); +}; + +struct TORCH_API adaptive_max_pool3d { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adaptive_max_pool3d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add.h new file mode 100644 index 0000000000000000000000000000000000000000..32b35b4c18bf79c79fc8b65413df7118a9db4ab4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add.h @@ -0,0 +1,54 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +inline at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_Tensor::call(self, other, alpha); +} + +// aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_out::call(self, other, alpha, out); +} +// aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & add_outf(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_out::call(self, other, alpha, out); +} + +// aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor +inline at::Tensor add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar::call(self, other, alpha); +} + +// aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar_out::call(self, other, alpha, out); +} +// aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) 
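The add.h wrappers above expose the Tensor and Scalar overloads plus the out/outf pair, all taking an optional alpha multiplier applied to the second operand. A short usage sketch, assuming a standard libtorch build (values are illustrative):

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::ones({2, 3});
  at::Tensor y = at::full({2, 3}, 2.0);

  at::Tensor z = at::add(x, y, /*alpha=*/0.5);    // x + 0.5 * y   (Tensor overload)
  at::Tensor w = at::add(x, 3.0, /*alpha=*/2.0);  // x + 2.0 * 3   (Scalar overload)

  at::Tensor out = at::empty_like(x);
  at::add_out(out, x, y, /*alpha=*/0.5);          // writes into a preallocated tensor
  at::add_outf(x, y, 0.5, out);                   // same op, out-last argument order
  return 0;
}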
+inline at::Tensor & add_outf(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_Scalar_out::call(self, other, alpha, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..08bfeb4fa2e48c74311edbf3f62afdc9f2a39fe2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_compositeexplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_outf(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..27c0cd355a6022aa0a35d6dfdad96b54a3e43adf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..a973c0538bcdbab828f9dc6965c045a281c144dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_outf(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & add_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..92dda7012e4fed8bf10b8eb59db24d50261445e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
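Besides the dispatcher-facing at::add, the per-backend headers above declare direct entry points such as at::cpu::add, which bind to the kernel registered for that dispatch key and skip dispatch entirely. A sketch, assuming CPU tensors and a standard libtorch build (the helper name is made up for illustration):

#include <ATen/ATen.h>
#include <ATen/ops/add_cpu_dispatch.h>

// Hypothetical helper; both tensors are assumed to live on CPU.
at::Tensor add_two_ways(const at::Tensor& a, const at::Tensor& b) {
  at::Tensor via_dispatcher = at::add(a, b, /*alpha=*/1);  // goes through the dispatcher
  at::Tensor direct = at::cpu::add(a, b, /*alpha=*/1);     // binds straight to the CPU kernel
  return via_dispatcher - direct;                          // zero tensor: same result
}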
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_outf(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & add_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..9939c09a0600604bd7ae29f562422299d57873bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_add_Tensor : public TensorIteratorBase { + + + void meta(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..126bbf4815668ac2404432be68a6e55a03cdfac4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
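structured_add_Tensor above derives from TensorIteratorBase and declares only meta(), the shape/dtype-inference half of the structured kernel; the computational impl() halves live in the native header further down. One observable consequence is that the operator can run on the Meta device, producing metadata without touching data. A sketch, assuming a libtorch build where the Meta device and at::kMeta are available:

#include <ATen/ATen.h>

int main() {
  // Meta tensors carry shape/dtype but no storage, so at::add runs only the
  // meta() half of the structured kernel declared above.
  at::Tensor a = at::empty({8, 1}, at::dtype(at::kFloat).device(at::kMeta));
  at::Tensor b = at::empty({1, 4}, at::dtype(at::kDouble).device(at::kMeta));
  at::Tensor c = at::add(a, b);  // broadcasting + type promotion, no data touched
  // c.sizes() == [8, 4], c.dtype() == double, c.device() == meta
  return 0;
}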
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_outf(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & add_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_native.h new file mode 100644 index 0000000000000000000000000000000000000000..e7ccb31d41769ee53e30febe31df0d98c8617a79 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_native.h @@ -0,0 +1,43 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_ufunc_add_CPU : public at::meta::structured_add_Tensor { +void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out); +}; +struct TORCH_API structured_ufunc_add_CUDA : public at::meta::structured_add_Tensor { +void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out); +}; +TORCH_API at::Tensor NestedTensor_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & NestedTensor_add__Tensor(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor add_sparse(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out_sparse_cpu(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & add_sparse_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out_sparse_cuda(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor add_sparse_csr(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out_sparse_compressed_cpu(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & add_sparse_csr_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_out_sparse_compressed_cuda(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor mkldnn_add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & mkldnn_add_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & mkldnn_add_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor add_zerotensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1); +TORCH_API at::Tensor & add_Scalar_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); +TORCH_API 
at::Tensor & add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/add_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..7aa42fb6b900b736afb15ed44fd0b308088698e8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/add_ops.h @@ -0,0 +1,84 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API add_Tensor { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::add"; + static constexpr const char* overload_name = "Tensor"; + static constexpr const char* schema_str = "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); +}; + +struct TORCH_API add__Tensor { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::add_"; + static constexpr const char* overload_name = "Tensor"; + static constexpr const char* schema_str = "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); +}; + +struct TORCH_API add_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::add"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); +}; + +struct TORCH_API add_Scalar { + using schema = at::Tensor (const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::add"; + static constexpr const char* overload_name = "Scalar"; + static constexpr const char* schema_str = "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); +}; + +struct TORCH_API add__Scalar { + using schema = at::Tensor & (at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::add_"; + static constexpr const char* overload_name = "Scalar"; + static constexpr const char* schema_str = "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); +}; + +struct TORCH_API add_Scalar_out { + using schema = at::Tensor & (const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::add"; + static constexpr const char* overload_name = "Scalar_out"; + static constexpr const char* schema_str = "add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm.h new file mode 100644 index 0000000000000000000000000000000000000000..58c648e71d5785da22b6dc25d1fad647169f8d82 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & addbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm_out::call(self, batch1, batch2, beta, alpha, out); +} +// aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addbmm_out::call(self, batch1, batch2, beta, alpha, out); +} + +// aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm::call(self, batch1, batch2, beta, alpha); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4b79e23aeeac7d2fb9d957d8e9129715f6fdd1f0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6bc7a7c0680d6278569997b0cba0cb28a265e6de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
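addbmm performs a batched matrix-multiply-and-reduce: beta * self + alpha * sum over i of batch1[i] @ batch2[i], collapsing the batch dimension. A usage sketch of the wrappers above, assuming libtorch (shapes are illustrative):

#include <ATen/ATen.h>

int main() {
  // batch1: b x n x m, batch2: b x m x p, result: n x p (here 3 x 5).
  at::Tensor self   = at::zeros({3, 5});
  at::Tensor batch1 = at::randn({4, 3, 2});
  at::Tensor batch2 = at::randn({4, 2, 5});

  at::Tensor r = at::addbmm(self, batch1, batch2, /*beta=*/1, /*alpha=*/1);

  at::Tensor out = at::empty({3, 5});
  at::addbmm_out(out, self, batch1, batch2);  // beta and alpha default to 1
  return 0;
}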
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6452de7679774bcddb9012c48260bb5e3beea5f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_meta_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_native.h new file mode 100644 index 0000000000000000000000000000000000000000..00e096100ae3594d3cec6fdf89d1832594a1d4f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..5053cb28112a48793c0b1fe6ee5649e23f851958 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addbmm_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the 
operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API addbmm_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addbmm_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addbmm_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addbmm"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +struct TORCH_API addbmm { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addbmm"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv.h new file mode 100644 index 0000000000000000000000000000000000000000..a33e1bb7c668df62bc1295065a6577d8eac34d1e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addcdiv.out(Tensor self, Tensor tensor1, 
Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv_out::call(self, tensor1, tensor2, value, out); +} +// aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcdiv_out::call(self, tensor1, tensor2, value, out); +} + +// aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv::call(self, tensor1, tensor2, value); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5c772d2c64dcd37674988ebfc110fbf642eb9f63 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..67600a3ac708466a3b8551bbc0254ae3ca7a4e6f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
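addcdiv computes self + value * (tensor1 / tensor2) elementwise. A usage sketch of the wrappers above, assuming libtorch (values are illustrative):

#include <ATen/ATen.h>

int main() {
  at::Tensor self = at::zeros({4});
  at::Tensor t1   = at::ones({4});
  at::Tensor t2   = at::full({4}, 2.0);
  // 0 + 0.5 * (1 / 2) = 0.25 in every element.
  at::Tensor r = at::addcdiv(self, t1, t2, /*value=*/0.5);
  return 0;
}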
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..85268d589375de645afc8e2cdb4f83b2180be8a6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..43e8d94efc5006c644d01168cb006488c19c60e5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_addcdiv : public TensorIteratorBase { + + + void meta(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1b1dca2d129367c1303657d280f40b69719ac5f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + 
+// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_native.h new file mode 100644 index 0000000000000000000000000000000000000000..4879a678a89e306980429c790c6bb793bf308639 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_addcdiv_out : public at::meta::structured_addcdiv { +void impl(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..e40a605fd25d24a49c8f5387373e2cfa4e5729de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcdiv_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API addcdiv_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addcdiv"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +}; + +struct TORCH_API addcdiv { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addcdiv"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +struct TORCH_API addcdiv_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addcdiv_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul.h new file mode 100644 index 0000000000000000000000000000000000000000..7eb88f50ae4159a76976b1928dff55d830087d18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addcmul_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul_out::call(self, tensor1, tensor2, value, out); +} +// aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & addcmul_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcmul_out::call(self, tensor1, tensor2, value, out); +} + +// aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul::call(self, tensor1, tensor2, value); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3c97b1dfdafa251d1f1ea3440bf0d62036b68bbf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..75a70786f370e14310946ec652b4228a68c1b543 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
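// Hand-written usage sketch (not part of the generated header): how the addcmul/addcdiv
// entry points declared in these files are typically reached through the at:: wrappers.
// Assumes <ATen/ATen.h> is available; tensor shapes and the helper name are illustrative.
inline void addcmul_usage_sketch() {
  at::Tensor self = at::zeros({2, 3});
  at::Tensor t1 = at::ones({2, 3});
  at::Tensor t2 = at::full({2, 3}, 4.0);
  // Functional form: result = self + value * t1 * t2 (value defaults to 1).
  at::Tensor r = at::addcmul(self, t1, t2, /*value=*/0.5);
  // out= form: writes into a preallocated tensor instead of allocating a new one.
  at::Tensor out = at::empty_like(self);
  at::addcmul_out(out, self, t1, t2);
  // addcdiv has the same signature but computes self + value * t1 / t2.
  at::Tensor q = at::addcdiv(self, t1, t2, /*value=*/0.5);
  (void)r; (void)q;
}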
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcmul_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ad8f29b9c668ec3329370975f8ae1025ba8504cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcmul_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..785fe7017f5f0c2c523ae523f919abec2174143b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_addcmul : public TensorIteratorBase { + + + void meta(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4f9a950bf2253b2a2d8dc809c4302937e22c37ed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + 
+// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcmul_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_native.h new file mode 100644 index 0000000000000000000000000000000000000000..72532cbf042f5084d11fb08e0d5ca6cac699c488 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_addcmul_out : public at::meta::structured_addcmul { +void impl(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..cf43ec9b60375ed870b66c8419ea536ee5803948 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addcmul_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API addcmul_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addcmul"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +}; + +struct TORCH_API addcmul { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addcmul"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +struct TORCH_API addcmul_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addcmul_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm.h new file mode 100644 index 0000000000000000000000000000000000000000..cc8202ce1f873fadee046733a370636564f67863 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm.h @@ -0,0 +1,54 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_out::call(self, mat1, mat2, beta, alpha, out); +} +// aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & addmm_outf(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmm_out::call(self, mat1, mat2, beta, alpha, out); +} + +// aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm::call(self, mat1, mat2, beta, alpha); +} + +// aten::addmm.dtype(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_dtype::call(self, mat1, mat2, out_dtype, beta, alpha); +} + +// aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_dtype_out::call(self, mat1, mat2, out_dtype, beta, alpha, out); +} +// aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addmm_outf(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmm_dtype_out::call(self, mat1, mat2, out_dtype, beta, alpha, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..feba8a394fecf46aeb4f7c50daaeff88554ae4dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
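// Hand-written usage sketch (not torchgen output): addmm computes beta * self + alpha * (mat1 @ mat2).
// Assumes <ATen/ATen.h>; shapes, values, and the helper name are illustrative.
inline void addmm_usage_sketch() {
  at::Tensor bias = at::zeros({2, 4});
  at::Tensor mat1 = at::rand({2, 3});
  at::Tensor mat2 = at::rand({3, 4});
  // Default beta = alpha = 1, so this is bias + mat1.mm(mat2).
  at::Tensor r = at::addmm(bias, mat1, mat2);
  // Explicit beta/alpha with the out= wrapper, which writes into a preallocated tensor.
  at::Tensor out = at::empty({2, 4});
  at::addmm_out(out, bias, mat1, mat2, /*beta=*/0.0, /*alpha=*/2.0);
  (void)r;
}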
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..985ad3f56d0f3586712431659d8f818fde11301e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_outf(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmm_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..b0e73735bb6f4c7eddc4d27d195bb3f254946344 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_cuda_dispatch.h @@ -0,0 +1,29 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
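// Hand-written sketch (not torchgen output): the per-backend namespaces declared in these
// dispatch headers (at::cpu, at::cuda, at::meta) expose entry points that call the kernel
// registered for that backend, skipping the usual dispatch on the inputs. Most user code
// should prefer the plain at:: wrappers; this only illustrates that the declarations are
// callable. Assumes CPU tensors and <ATen/ATen.h>.
inline void addmm_cpu_direct_sketch() {
  at::Tensor bias = at::zeros({2, 2});
  at::Tensor a = at::rand({2, 3});
  at::Tensor b = at::rand({3, 2});
  // For CPU inputs this yields the same result as at::addmm(bias, a, b).
  at::Tensor r = at::cpu::addmm(bias, a, b);
  (void)r;
}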
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_outf(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmm_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_outf(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..36f1ffec6a4af5b404dfd3dc3a81033d49d7abcc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_addmm : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..04fa482acc86693c26cbccc0f653983a43bca6c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_outf(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmm_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_native.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb6cec4ab308c765d7ce9b93c67c7b66a204487 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_native.h @@ -0,0 +1,37 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_addmm_out_cpu : public at::meta::structured_addmm { +void impl(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, const at::Tensor & out); +}; +struct TORCH_API structured_addmm_out_cuda : public at::meta::structured_addmm { +void impl(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, const at::Tensor & out); +}; +TORCH_API at::Tensor addmm_sparse_dense_cpu(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_out_sparse_dense_cpu(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & s_addmm_sparse_dense_cpu_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor addmm_sparse_dense_cuda(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_out_sparse_dense_cuda(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & s_addmm_sparse_dense_cuda_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor addmm_sparse_compressed_dense(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmm_out_sparse_compressed_cpu(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmm_out_sparse_compressed_cuda(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor 
_addmm_dtype_cuda(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & _addmm_dtype_out_cuda(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..5967f165550f17afa270a7b07a660c32ef60badd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmm_ops.h @@ -0,0 +1,73 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API addmm_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmm"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +struct TORCH_API addmm { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmm"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addmm_dtype { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::ScalarType, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmm"; + static constexpr const char* overload_name = "dtype"; + static constexpr const char* schema_str = "addmm.dtype(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const 
at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addmm_dtype_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::ScalarType, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmm"; + static constexpr const char* overload_name = "dtype_out"; + static constexpr const char* schema_str = "addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +struct TORCH_API addmm_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmm_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv.h new file mode 100644 index 0000000000000000000000000000000000000000..3eda3878a468a548ca30e97398d0caf54ae6beed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor addmv(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv::call(self, mat, vec, beta, alpha); +} + +// aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) 
+inline at::Tensor & addmv_(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv_::call(self, mat, vec, beta, alpha); +} + +// aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addmv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv_out::call(self, mat, vec, beta, alpha, out); +} +// aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addmv_outf(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmv_out::call(self, mat, vec, beta, alpha, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..df767d5de1e310245ceac36dd607f7e07859b323 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor addmv(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmv_(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..9cad9726894ca51e335aef7cb6256d1574b354f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
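// Hand-written sketch (not torchgen output): addmv is the matrix-vector analogue of addmm,
// computing beta * self + alpha * (mat @ vec). Assumes <ATen/ATen.h>; the helper name is illustrative.
inline void addmv_usage_sketch() {
  at::Tensor self = at::zeros({2});      // 1-D, length mat.size(0)
  at::Tensor mat = at::rand({2, 3});
  at::Tensor vec = at::rand({3});
  at::Tensor r = at::addmv(self, mat, vec);               // beta = alpha = 1 by default
  at::addmv_(self, mat, vec, /*beta=*/1, /*alpha=*/0.5);  // in-place variant declared above
  (void)r;
}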
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addmv(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmv_outf(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmv_(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d921c70bb6bb1b80595dde08ea1f911f2b45863e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addmv(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmv_outf(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmv_(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..2ca80da3f0124c33a8dde8e767d69c0cf751811b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_addmv : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1cb5405e21617c327127a4a0d5a0380ac2c67223 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor addmv(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addmv_outf(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmv_(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_native.h new file mode 100644 index 0000000000000000000000000000000000000000..356a8b9f31cc2824f722bcad7f05fb3d9cc0d80e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_native.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_addmv_out_cpu : public at::meta::structured_addmv { +void impl(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, const at::Tensor & out); +}; +struct TORCH_API structured_addmv_out_cuda : public at::meta::structured_addmv { +void impl(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, const at::Tensor & out); +}; +TORCH_API at::Tensor & addmv_out_sparse_compressed(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addmv_out_sparse_compressed_cuda(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..4c9a2b2e2aa1b07e35895baa6fb0aba8abad1472 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addmv_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API addmv { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmv"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addmv_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmv_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addmv_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addmv"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addr.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr.h new file mode 100644 index 0000000000000000000000000000000000000000..078c8a49ad8a3b6708d7ef3d3306f603e2a556dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr::call(self, vec1, vec2, beta, alpha); +} + +// aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addr_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr_out::call(self, vec1, vec2, beta, alpha, out); +} +// aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addr_outf(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addr_out::call(self, vec1, vec2, beta, alpha, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..2ed955bd60b82af58d678998c239e93f4b3d8318 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_compositeexplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addr_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addr_outf(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addr_(at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1cd5ceb003ec2e4d7487c04dc81c7b377aecac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addr_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addr_outf(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e7ed4560edd879f8b65a1f9347326e57d140cd95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
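// Hand-written sketch (not torchgen output): addr computes beta * self + alpha * outer(vec1, vec2),
// i.e. a rank-1 update of the matrix `self`. Assumes <ATen/ATen.h>; shapes and the helper name are illustrative.
inline void addr_usage_sketch() {
  at::Tensor self = at::zeros({2, 3});
  at::Tensor vec1 = at::rand({2});
  at::Tensor vec2 = at::rand({3});
  at::Tensor r = at::addr(self, vec1, vec2);   // beta = alpha = 1 by default
  at::Tensor out = at::empty({2, 3});
  at::addr_out(out, self, vec1, vec2);         // out= variant declared above
  (void)r;
}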
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addr_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addr_outf(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_native.h new file mode 100644 index 0000000000000000000000000000000000000000..e6ab7cde2a435ccec9b1ea4eb11911cffc4acc42 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_native.h @@ -0,0 +1,25 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor math_addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & math_addr_out(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addr_(at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addr_out(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..1807cd03a79c4bf82c6f78e9bea5a3e22d4dc189 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/addr_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
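// Hand-written sketch (not torchgen output): each struct in at::_ops, such as addr below,
// pairs the operator's schema string with static call/redispatch entry points, and the inline
// at:: wrappers shown in the Function.h headers simply forward to call(). The two statements
// below are therefore equivalent; the explicit form just spells out the defaulted scalars.
inline void addr_ops_call_sketch(const at::Tensor & self, const at::Tensor & v1, const at::Tensor & v2) {
  at::Tensor a = at::addr(self, v1, v2);                                        // convenience wrapper
  at::Tensor b = at::_ops::addr::call(self, v1, v2, /*beta=*/1, /*alpha=*/1);   // explicit dispatcher entry point
  (void)a; (void)b;
}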
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API addr { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addr"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addr_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addr_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addr_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::addr"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint.h new file mode 100644 index 0000000000000000000000000000000000000000..a95ce3cbc0c9ce68a99192c81c705c8cecaabb69 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adjoint(Tensor(a) self) -> Tensor(a) +inline at::Tensor adjoint(const at::Tensor & self) { + return at::_ops::adjoint::call(self); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..02f79b5a3d82a69433926aa805fd6fe88f70d24c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor adjoint(const at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_native.h new file mode 100644 index 0000000000000000000000000000000000000000..dd1594efafbf02af327a3ff5a9417aad1b7016cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor adjoint(const at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..b989c1dbe571ca363373ff73df14848e8ca4eb6a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/adjoint_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adjoint { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::adjoint"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "adjoint(Tensor(a) self) -> Tensor(a)"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..f143e64474c83bbf14d2f255426d2fc661f99ccf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator.h @@ -0,0 +1,92 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor +inline at::Tensor affine_grid_generator(const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::call(theta, c10::fromIntArrayRefSlow(size), align_corners); +} +namespace symint { + template >> + at::Tensor affine_grid_generator(const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::call(theta, c10::fromIntArrayRefSlow(size), align_corners); + } +} + +// aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor +inline at::Tensor affine_grid_generator_symint(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::call(theta, size, align_corners); +} +namespace symint { + template >> + at::Tensor affine_grid_generator(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::call(theta, size, align_corners); + } +} + +// aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & affine_grid_generator_out(at::Tensor & out, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::call(theta, c10::fromIntArrayRefSlow(size), align_corners, out); +} +namespace symint { + template >> + at::Tensor & affine_grid_generator_out(at::Tensor & out, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::call(theta, c10::fromIntArrayRefSlow(size), align_corners, out); + } +} + +// aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & affine_grid_generator_outf(const at::Tensor & theta, at::IntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::call(theta, c10::fromIntArrayRefSlow(size), align_corners, out); +} +namespace symint { + template >> + at::Tensor & affine_grid_generator_outf(const at::Tensor & theta, at::IntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::call(theta, c10::fromIntArrayRefSlow(size), align_corners, out); + } +} + +// aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & affine_grid_generator_symint_out(at::Tensor & out, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::call(theta, size, align_corners, out); +} +namespace symint { + template >> + at::Tensor & affine_grid_generator_out(at::Tensor & out, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::call(theta, size, align_corners, out); + } +} + +// aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & affine_grid_generator_symint_outf(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::call(theta, size, align_corners, out); +} +namespace symint { + template >> + at::Tensor & affine_grid_generator_outf(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::call(theta, size, align_corners, out); + } +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..62bd3925f42e8e48d561c3c1fb423f6e4b3de29b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward.h @@ -0,0 +1,48 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor +inline at::Tensor affine_grid_generator_backward(const at::Tensor & grad, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::call(grad, c10::fromIntArrayRefSlow(size), align_corners); +} +namespace symint { + template >> + at::Tensor affine_grid_generator_backward(const at::Tensor & grad, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::call(grad, c10::fromIntArrayRefSlow(size), align_corners); + } +} + +// aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor +inline at::Tensor affine_grid_generator_backward_symint(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::call(grad, size, align_corners); +} +namespace symint { + template >> + at::Tensor affine_grid_generator_backward(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::call(grad, size, 
align_corners); + } +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4f21d3acb2d52fc997fd4dcd0f6fa4a6e2378426 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor affine_grid_generator_backward(const at::Tensor & grad, at::IntArrayRef size, bool align_corners); +TORCH_API at::Tensor affine_grid_generator_backward_symint(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..55d6ae36732d3166544e0f5f891f06c3ca6cff69 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor affine_grid_generator_backward(const at::Tensor & grad, at::IntArrayRef size, bool align_corners); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0c6d31f2cc284061defb0805f9339b862dc3c33d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_backward_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
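As a hedged sketch of the functional form declared above: affine_grid_generator takes an {N, 2, 3} batch of 2D affine matrices plus the target {N, C, H, W} size and yields an {N, H, W, 2} sampling grid.

#include <ATen/ATen.h>

// Identity affine transform for one 3-channel 8x8 image (illustrative only).
at::Tensor make_identity_grid() {
  at::Tensor theta = at::eye(2, 3).unsqueeze(0);  // shape {1, 2, 3}
  return at::affine_grid_generator(theta, {1, 3, 8, 8}, /*align_corners=*/false);
  // result shape: {1, 8, 8, 2}
}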
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API affine_grid_generator_backward { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::affine_grid_generator_backward"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor"; + static at::Tensor call(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..a3687e9ae4908587c86306023c4400b5438fab36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_compositeexplicitautograd_dispatch.h @@ -0,0 +1,28 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor affine_grid_generator(const at::Tensor & theta, at::IntArrayRef size, bool align_corners); +TORCH_API at::Tensor affine_grid_generator_symint(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners); +TORCH_API at::Tensor & affine_grid_generator_out(at::Tensor & out, const at::Tensor & theta, at::IntArrayRef size, bool align_corners); +TORCH_API at::Tensor & affine_grid_generator_outf(const at::Tensor & theta, at::IntArrayRef size, bool align_corners, at::Tensor & out); +TORCH_API at::Tensor & affine_grid_generator_symint_out(at::Tensor & out, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners); +TORCH_API at::Tensor & affine_grid_generator_symint_outf(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_native.h new file mode 100644 index 0000000000000000000000000000000000000000..7f82c0633d1a7c403b51f25177de9878fd55c075 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor affine_grid_generator(const at::Tensor & theta, at::IntArrayRef size, bool align_corners); +TORCH_API at::Tensor & affine_grid_generator_out_symint(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..ad81bc78822da964b800fc27bd99901ff2d7e61f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/affine_grid_generator_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API affine_grid_generator { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::affine_grid_generator"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor"; + static at::Tensor call(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners); +}; + +struct TORCH_API affine_grid_generator_out { + using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::affine_grid_generator"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias.h new file mode 100644 index 0000000000000000000000000000000000000000..65096db0eae48f2ed47927dfff552b1689af4b49 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::alias(Tensor(a) self) -> Tensor(a) +inline at::Tensor alias(const at::Tensor & self) { + return at::_ops::alias::call(self); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..2b225912bcf354a059c5061ba1aaa24a4bd887b8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_compositeexplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
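A brief semantic sketch, under the usual ATen conventions: at::alias (declared above) returns a tensor sharing storage with its input, while the alias_copy variant declared in the following header materializes an independent result (the *_copy form of a view op, as used by functionalization).

#include <ATen/ATen.h>

void alias_demo() {
  at::Tensor x = at::ones({2, 2});
  at::Tensor v = at::alias(x);       // shares storage with x
  at::Tensor c = at::alias_copy(x);  // independent of x
  x.add_(1);                         // v observes the update, c does not
}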
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor alias(const at::Tensor & self); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy.h new file mode 100644 index 0000000000000000000000000000000000000000..3527859b9490cdd46f58beb057d5217530790be8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::alias_copy(Tensor self) -> Tensor +inline at::Tensor alias_copy(const at::Tensor & self) { + return at::_ops::alias_copy::call(self); +} + +// aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & alias_copy_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::alias_copy_out::call(self, out); +} +// aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & alias_copy_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::alias_copy_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..17caef36a3ac4d89a0b6fa527133cb0c0fb7ce4a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & alias_copy_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & alias_copy_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e6ac24a4936c56669f02f56f59cd111541354ad6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor alias_copy(const at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_native.h new file mode 100644 index 0000000000000000000000000000000000000000..bceb7e6ca9884f69b42c8c287b52019cd0f26cdb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor & alias_copy_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor alias_copy(const at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..acb31c73347e8708fe38de3aa7dc03021e63174b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_copy_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API alias_copy { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::alias_copy"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "alias_copy(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API alias_copy_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::alias_copy"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "alias_copy.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_native.h new file mode 100644 index 0000000000000000000000000000000000000000..4db0f4e31285b334fe6d72809320edd390d7fa8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor alias(const at::Tensor & self); +TORCH_API at::Tensor alias_nested(const at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..cd3c112556ff3f362d0ed890c56706df646bb015 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alias_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API alias { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::alias"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "alias(Tensor(a) self) -> Tensor(a)"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as.h new file mode 100644 index 0000000000000000000000000000000000000000..be131ee828fb8ffe9d5f0976d28de567c7e144ab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8accbb0058e9e9e1f1e6dec230fe75c889dea324 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of 
any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor align_as(const at::Tensor & self, const at::Tensor & other); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_native.h new file mode 100644 index 0000000000000000000000000000000000000000..09d513b736f96a98fe96625c0df1c74412be31cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor align_as(const at::Tensor & self, const at::Tensor & other); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0eacefe9faf9b8a30bbcc20868a5434c04783972 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_as_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API align_as { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::align_as"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "align_as(Tensor self, Tensor other) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & other); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors.h new file mode 100644 index 0000000000000000000000000000000000000000..2f780a2483635daf2304c03fda7d09959b54cab7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::align_tensors(Tensor[] tensors) -> Tensor[] +inline ::std::vector align_tensors(at::TensorList tensors) { + return at::_ops::align_tensors::call(tensors); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8cba221dc1a30a82e111fbb6527cb6d97a121e5f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API ::std::vector align_tensors(at::TensorList tensors); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_native.h new file mode 100644 index 0000000000000000000000000000000000000000..420fb4ae12839849d96606662d741823caedf32a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::vector align_tensors(at::TensorList tensors); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..425daba29872579e6c812e5e3799618bcbeadf3f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_tensors_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API align_tensors { + using schema = ::std::vector (at::TensorList); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::align_tensors"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "align_tensors(Tensor[] tensors) -> Tensor[]"; + static ::std::vector call(at::TensorList tensors); + static ::std::vector redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to.h new file mode 100644 index 0000000000000000000000000000000000000000..9f161eb3f554b53c24442e6dff2ce888b3ae7fef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..03dc229690ea8ef81e6f5ee109a3c792e8f1277e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed 
in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor align_to(const at::Tensor & self, at::DimnameList names); +TORCH_API at::Tensor align_to(const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_native.h new file mode 100644 index 0000000000000000000000000000000000000000..8da9def4e281a3c93129c3f837b084caf278c6c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor align_to(const at::Tensor & self, at::DimnameList names); +TORCH_API at::Tensor align_to(const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..dda98ec2fc78e57cb991cc84c9f2636f3516be05 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/align_to_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API align_to { + using schema = at::Tensor (const at::Tensor &, at::DimnameList); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::align_to"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)"; + static at::Tensor call(const at::Tensor & self, at::DimnameList names); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList names); +}; + +struct TORCH_API align_to_ellipsis_idx { + using schema = at::Tensor (const at::Tensor &, at::DimnameList, int64_t); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::align_to"; + static constexpr const char* overload_name = "ellipsis_idx"; + static constexpr const char* schema_str = "align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)"; + static at::Tensor call(const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all.h new file mode 100644 index 0000000000000000000000000000000000000000..0573625a871b2cd00c08a87ef66989084b30a138 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all.h @@ -0,0 +1,82 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor +inline at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::all_dim::call(self, dim, keepdim); +} + +// aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::all_dims::call(self, dim, keepdim); +} + +// aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::all_out::call(self, dim, keepdim, out); +} +// aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & all_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_out::call(self, dim, keepdim, out); +} + +// aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::all_dims_out::call(self, dim, keepdim, out); +} +// aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & all_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_dims_out::call(self, dim, keepdim, out); +} + +// aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor +inline at::Tensor all(const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::all_dimname::call(self, dim, keepdim); +} + +// aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::all_dimname_out::call(self, dim, keepdim, out); +} +// aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & all_outf(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_dimname_out::call(self, dim, keepdim, out); +} + +// aten::all(Tensor self) -> Tensor +inline at::Tensor all(const at::Tensor & self) { + return at::_ops::all::call(self); +} + +// aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & all_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::all_all_out::call(self, out); +} +// aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & all_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::all_all_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..a954e01fdd0fec2f9b3b6e65ac74237c8354f7e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeexplicitautograd_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
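A minimal sketch of the at::all overloads declared above (full reduction, reduction over one dim, and the keepdim variant); argument types and defaults follow the schema strings shown in the headers.

#include <ATen/ATen.h>

void all_demo() {
  at::Tensor mask = at::ones({2, 3}, at::kBool);
  bool every      = at::all(mask).item<bool>();                  // aten::all -> scalar bool
  at::Tensor rows = at::all(mask, /*dim=*/1);                    // aten::all.dim, shape {2}
  at::Tensor kept = at::all(mask, /*dim=*/1, /*keepdim=*/true);  // shape {2, 1}
}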
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..188e562468b9826f40e87928bb8dc0fb51f5676a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor all(const at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..20700c656c4d904987c38d961574cbcf1329f896 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_compositeimplicitautograd_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor all(const at::Tensor & self, at::Dimname dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..799700e8c9827e3d7e9cb996a2fd678e8c671d51 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_cpu_dispatch.h @@ -0,0 +1,31 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor all(const at::Tensor & self); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c5a7db38a5cfa4c5b67449611d7c8b5c4785a223 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_cuda_dispatch.h @@ -0,0 +1,31 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
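These *_dispatch.h headers expose the same signatures under per-key namespaces (at::cpu, at::cuda, at::meta, at::compositeexplicitautograd, and so on); calling one of them runs the kernel registered for that dispatch key directly, so it is only valid when the inputs already live on that backend. A hedged sketch:

#include <ATen/ATen.h>
#include <ATen/ops/all_cpu_dispatch.h>

void backend_specific_all() {
  at::Tensor t = at::zeros({4}, at::kBool);
  at::Tensor r = at::cpu::all(t);  // invokes the CPU kernel for aten::all directly
}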
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor all(const at::Tensor & self); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..b90771ca9b5526e1b270b601280c73d4afc59ff0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_meta.h @@ -0,0 +1,37 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_all_dim : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, int64_t dim, bool keepdim); +}; +struct TORCH_API structured_all_dims : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); +}; +struct TORCH_API structured_all : public at::impl::MetaBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d00f609522a803270d4762fd42b03de2596ae1cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_meta_dispatch.h @@ -0,0 +1,31 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor all(const at::Tensor & self); +TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_native.h new file mode 100644 index 0000000000000000000000000000000000000000..ded795da0415782bef98296074a0f6173f12b20f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_native.h @@ -0,0 +1,34 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_all_out : public at::meta::structured_all_dim { +void impl(const at::Tensor & self, int64_t dim, bool keepdim, const at::Tensor & out); +}; +TORCH_API at::Tensor NestedTensor_all(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor all_dims_default(const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & all_dims_out_default(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +struct TORCH_API structured_all_dims_out : public at::meta::structured_all_dims { +void impl(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, const at::Tensor & out); +}; +TORCH_API at::Tensor all(const at::Tensor & self, at::Dimname dim, bool keepdim=false); +TORCH_API at::Tensor & all_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); +struct TORCH_API structured_all_all_out : public at::meta::structured_all { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/all_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..2edb06debbf05844d44ecb500ffad129abdd977b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/all_ops.h @@ -0,0 +1,106 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API all_dim { + using schema = at::Tensor (const at::Tensor &, int64_t, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = "dim"; + static constexpr const char* schema_str = "all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, int64_t dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim); +}; + +struct TORCH_API all_dims { + using schema = at::Tensor (const at::Tensor &, at::OptionalIntArrayRef, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = "dims"; + static constexpr const char* schema_str = "all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); +}; + +struct TORCH_API all_out { + using schema = at::Tensor & (const at::Tensor &, int64_t, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +}; + +struct TORCH_API all_dims_out { + using schema = at::Tensor & (const at::Tensor &, at::OptionalIntArrayRef, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = "dims_out"; + static constexpr const char* schema_str = "all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +}; + +struct TORCH_API all_dimname { + using schema = at::Tensor (const at::Tensor &, at::Dimname, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = "dimname"; + static constexpr const char* schema_str = "all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::Dimname dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim); +}; + +struct TORCH_API all_dimname_out { + using schema = at::Tensor & (const at::Tensor &, at::Dimname, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = "dimname_out"; + static constexpr const char* schema_str = "all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); +}; + +struct TORCH_API all { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "all(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API all_all_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::all"; + static constexpr const char* overload_name = "all_out"; + static constexpr const char* schema_str = "all.all_out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose.h new file mode 100644 index 0000000000000000000000000000000000000000..87ba5e602782fb12faff82c1c3cc7445c783ea95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool +inline bool allclose(const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) { + return at::_ops::allclose::call(self, other, rtol, atol, equal_nan); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..266c4eaf6cbf92ecf30c204a2f24a2bf3a4a069d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_compositeexplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
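The inline at::allclose wrapper shown above simply forwards to at::_ops::allclose::call with the schema defaults. A minimal usage sketch, assuming libtorch; the tolerances and tensor values here are illustrative, not recommendations.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto a = torch::tensor({1.0, 2.0, 3.0});
      auto b = a + 1e-9;   // tiny perturbation

      // Defaults mirror the schema: rtol=1e-05, atol=1e-08, equal_nan=False.
      bool close  = at::allclose(a, b);
      // All arguments are positional in the C++ API.
      bool strict = at::allclose(a, b, /*rtol=*/0.0, /*atol=*/1e-12, /*equal_nan=*/true);

      std::cout << std::boolalpha << close << " " << strict << "\n";
      return 0;
    }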
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API bool allclose(const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_native.h new file mode 100644 index 0000000000000000000000000000000000000000..c9d1d5558ee5bcecc02b6cb5049a22f8b086fd07 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API bool allclose(const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..b04828f245a9fbb1ae7c92d73c33df091deed880 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/allclose_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API allclose { + using schema = bool (const at::Tensor &, const at::Tensor &, double, double, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::allclose"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool"; + static bool call(const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan); + static bool redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout.h new file mode 100644 index 0000000000000000000000000000000000000000..7ee4977e4b1a56b31b83a84714a3ef46e11140e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout.h @@ -0,0 +1,36 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor +inline at::Tensor alpha_dropout(const at::Tensor & input, double p, bool train) { + return at::_ops::alpha_dropout::call(input, p, train); +} + +// aten::alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) 
+inline at::Tensor & alpha_dropout_(at::Tensor & self, double p, bool train) { + return at::_ops::alpha_dropout_::call(self, p, train); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5bfac5ce20edd3f24149822c5705e299c320718b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor alpha_dropout(const at::Tensor & input, double p, bool train); +TORCH_API at::Tensor & alpha_dropout_(at::Tensor & self, double p, bool train); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_native.h new file mode 100644 index 0000000000000000000000000000000000000000..725ea23ed4b09dce3a3b34bb1edd8d6678252d01 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor alpha_dropout(const at::Tensor & input, double p, bool train); +TORCH_API at::Tensor & alpha_dropout_(at::Tensor & self, double p, bool train); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..8a9aa17495aa42dcd4d880744943c5a8779ceb35 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/alpha_dropout_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
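A small usage sketch for the two entry points declared above, the functional aten::alpha_dropout and the in-place aten::alpha_dropout_. It assumes libtorch; the shape and the dropout probability are illustrative.

    #include <torch/torch.h>

    int main() {
      auto x = torch::randn({4, 8});

      // Out-of-place: returns a new tensor, `x` is untouched.
      auto y = at::alpha_dropout(x, /*p=*/0.5, /*train=*/true);

      // In-place: mutates `x` and returns a reference to it.
      at::alpha_dropout_(x, /*p=*/0.5, /*train=*/true);

      // With train=false both variants act as identity.
      auto z = at::alpha_dropout(x, /*p=*/0.5, /*train=*/false);
      return 0;
    }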
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API alpha_dropout { + using schema = at::Tensor (const at::Tensor &, double, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::alpha_dropout"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "alpha_dropout(Tensor input, float p, bool train) -> Tensor"; + static at::Tensor call(const at::Tensor & input, double p, bool train); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train); +}; + +struct TORCH_API alpha_dropout_ { + using schema = at::Tensor & (at::Tensor &, double, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::alpha_dropout_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, double p, bool train); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax.h new file mode 100644 index 0000000000000000000000000000000000000000..e21288e7081b29e7e926387f3501e03868ce4312 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor +inline at::Tensor amax(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amax::call(self, dim, keepdim); +} + +// aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & amax_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amax_out::call(self, dim, keepdim, out); +} +// aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & amax_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::amax_out::call(self, dim, keepdim, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0bc1bf6673c4682e4b7db3d2898bc7f832bc6950 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor amax(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..956671ceae77f09c9fd84dc1b25593e3f54b43ab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor amax(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amax_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amax_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1091904d66c58465ae2cb0565d5b33d501a85b4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor amax(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amax_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amax_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..bdbc9f3e1e9da2ac97680542647ecf4b1df74f68 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_amax : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..78591ca0076cfa7bdb4a12c96094323ed61783f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
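The aten::amax declarations above (functional and out= forms, with dim as an int[1] list) can be exercised roughly as follows. This is a hedged sketch assuming libtorch; values are illustrative.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto t = torch::tensor({{1.0, 4.0},
                              {3.0, 2.0}});

      auto col_max = at::amax(t, /*dim=*/{0});                    // column maxima -> {3, 4}
      auto row_max = at::amax(t, /*dim=*/{1}, /*keepdim=*/true);  // shape [2, 1]

      auto out = torch::empty({2}, t.options());
      at::amax_out(out, t, /*dim=*/{1});                          // out= variant

      std::cout << col_max << "\n" << row_max << "\n" << out << "\n";
      return 0;
    }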
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor amax(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amax_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amax_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_native.h new file mode 100644 index 0000000000000000000000000000000000000000..2e58cbc8fa92bd7e8058976109cf93c1028e4075 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_amax_out : public at::meta::structured_amax { +void impl(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..3fca269a7462a403702683be6dee3f09de3f1127 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amax_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API amax { + using schema = at::Tensor (const at::Tensor &, at::IntArrayRef, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::amax"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim); +}; + +struct TORCH_API amax_out { + using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::amax"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin.h new file mode 100644 index 0000000000000000000000000000000000000000..b40703db5d60706467e9805a6a4ac2bb4c4c6c7a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor +inline at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amin::call(self, dim, keepdim); +} + +// aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & amin_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amin_out::call(self, dim, keepdim, out); +} +// aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & amin_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::amin_out::call(self, dim, keepdim, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7afef4dffde98019e919468a98bded8779d130d1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
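Each at::_ops struct above pairs the registered schema string with typed call/redispatch entry points, and the inline at::amax wrapper shown earlier forwards to at::_ops::amax::call. A rough sketch of touching the struct directly, assuming libtorch; this is rarely needed in user code and is shown only to make the generated layering concrete.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto t = torch::randn({3, 3});

      // Equivalent to at::amax(t, {0}): the public wrapper forwards here.
      auto m = at::_ops::amax::call(t, /*dim=*/{0}, /*keepdim=*/false);

      // The schema metadata is exposed as static members.
      std::cout << at::_ops::amax::name << " | "
                << at::_ops::amax_out::overload_name << "\n"
                << at::_ops::amax::schema_str << "\n"
                << m.sizes() << "\n";
      return 0;
    }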
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1d436b64ae162bed4b33aee8e333b142a58ca713 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amin_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amin_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8fca63e45ad8b99295c7c2b0d6f9440acc7f5d71 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amin_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amin_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..798cdc45aa7b9ee5be8cb595c17ed1738f29fe70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_amin : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d63d6d2795b0e056262e155a9cd10a4d46c6164d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amin_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false); +TORCH_API at::Tensor & amin_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_native.h new file mode 100644 index 0000000000000000000000000000000000000000..36ee2bf62b78a0b732a69b9d187bea0ae5eef1df --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_amin_out : public at::meta::structured_amin { +void impl(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..2e5844ed0d602bb6dea74ec5e6b32468819d5166 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/amin_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API amin { + using schema = at::Tensor (const at::Tensor &, at::IntArrayRef, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::amin"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim); +}; + +struct TORCH_API amin_out { + using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::amin"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax.h new file mode 100644 index 0000000000000000000000000000000000000000..d45cb45ec46029fe3ba995bd42f40003e733d525 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) +inline ::std::tuple aminmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::aminmax::call(self, dim, keepdim); +} + +// aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) +inline ::std::tuple aminmax_out(at::Tensor & min, at::Tensor & max, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::aminmax_out::call(self, dim, keepdim, min, max); +} +// aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) +inline ::std::tuple aminmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max) { + return at::_ops::aminmax_out::call(self, dim, keepdim, min, max); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8fd965f2a09a4f96ab37790f7a4571b72201e22d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
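at::amin mirrors the at::amax calls sketched earlier; aten::aminmax, declared in the hunks just shown, fuses both reductions and returns a (min, max) pair. A minimal sketch assuming libtorch and C++17 structured bindings, with illustrative values.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto t = torch::tensor({{1.0, 4.0},
                              {3.0, 2.0}});

      auto [mn, mx] = at::aminmax(t);                              // dim omitted: global reduction
      auto [row_min, row_max] = at::aminmax(t, /*dim=*/1, /*keepdim=*/false);

      // out= variant writes into pre-allocated tensors.
      auto min_out = torch::empty({2}, t.options());
      auto max_out = torch::empty({2}, t.options());
      at::aminmax_out(min_out, max_out, t, /*dim=*/1);

      std::cout << mn.item<double>() << " " << mx.item<double>() << "\n";
      return 0;
    }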
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API ::std::tuple aminmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0eeb90bca0d1447d4183001b9a97d980b622d2fd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::tuple aminmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API ::std::tuple aminmax_out(at::Tensor & min, at::Tensor & max, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API ::std::tuple aminmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..15aab1f780ef432d0ad499988641d55a0489e66e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple aminmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API ::std::tuple aminmax_out(at::Tensor & min, at::Tensor & max, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API ::std::tuple aminmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..edcb7742aeffe1d5c283fae8f639401a4fc3db67 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_aminmax : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, ::std::optional dim, bool keepdim); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1b410252de49213153ac18313214021e929fd064 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API ::std::tuple aminmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API ::std::tuple aminmax_out(at::Tensor & min, at::Tensor & max, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API ::std::tuple aminmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_native.h new file mode 100644 index 0000000000000000000000000000000000000000..0d2fdc9b44d56ade40442c94b10142ad5d8e8e6a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_aminmax_out : public at::meta::structured_aminmax { +void impl(const at::Tensor & self, ::std::optional dim, bool keepdim, const at::Tensor & min, const at::Tensor & max); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..ded09002be03259aee2757ef92470b056c7139ca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/aminmax_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API aminmax { + using schema = ::std::tuple (const at::Tensor &, ::std::optional, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::aminmax"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)"; + static ::std::tuple call(const at::Tensor & self, ::std::optional dim, bool keepdim); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim); +}; + +struct TORCH_API aminmax_out { + using schema = ::std::tuple (const at::Tensor &, ::std::optional, bool, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::aminmax"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) 
max)"; + static ::std::tuple call(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/and.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/and.h new file mode 100644 index 0000000000000000000000000000000000000000..679e7b540dd8b89429ee0769dc5a0de32fdf0308 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/and.h @@ -0,0 +1,36 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor __and__(const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__and___Scalar::call(self, other); +} + +// aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor __and__(const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__and___Tensor::call(self, other); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/and_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/and_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1e04786a31a34f2639f8d52fa04b35f5b912af1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/and_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor __and__(const at::Tensor & self, const at::Scalar & other); +TORCH_API at::Tensor & __iand__(at::Tensor & self, const at::Scalar & other); +TORCH_API at::Tensor __and__(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & __iand__(at::Tensor & self, const at::Tensor & other); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/and_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/and_native.h new file mode 100644 index 0000000000000000000000000000000000000000..e4405f86d64c25d0792324f5c8174793d48a673c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/and_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor __and__(const at::Tensor & self, const at::Scalar & other); +TORCH_API at::Tensor & __iand__(at::Tensor & self, const at::Scalar & other); +TORCH_API at::Tensor __and__(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & __iand__(at::Tensor & self, const at::Tensor & other); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/and_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/and_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..91184e81a73dfc843eaa7d5ee4454978e66bcd90 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/and_ops.h @@ -0,0 +1,62 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
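The entries above back the explicit at::__and__ / __iand__ free functions; the & and &= operators on at::Tensor are thin sugar that routes through the same aten::__and__ family. A brief sketch assuming libtorch, with bool tensors chosen only for illustration.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto a = torch::tensor({true, true,  false});
      auto b = torch::tensor({true, false, false});

      auto both  = at::__and__(a, b);   // aten::__and__.Tensor
      auto sugar = a & b;               // operator& sugar over the same op family

      a &= b;                           // in-place, aten::__iand__.Tensor

      std::cout << both << "\n" << sugar << "\n" << a << "\n";
      return 0;
    }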
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API __and___Scalar { + using schema = at::Tensor (const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::__and__"; + static constexpr const char* overload_name = "Scalar"; + static constexpr const char* schema_str = "__and__.Scalar(Tensor self, Scalar other) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Scalar & other); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other); +}; + +struct TORCH_API __and___Tensor { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::__and__"; + static constexpr const char* overload_name = "Tensor"; + static constexpr const char* schema_str = "__and__.Tensor(Tensor self, Tensor other) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & other); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other); +}; + +struct TORCH_API __iand___Scalar { + using schema = at::Tensor & (at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::__iand__"; + static constexpr const char* overload_name = "Scalar"; + static constexpr const char* schema_str = "__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Scalar & other); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other); +}; + +struct TORCH_API __iand___Tensor { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::__iand__"; + static constexpr const char* overload_name = "Tensor"; + static constexpr const char* schema_str = "__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & other); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/angle.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle.h new file mode 100644 index 0000000000000000000000000000000000000000..1b4309f89fce1f2b385f80538d138c41c16f79de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::angle(Tensor self) -> Tensor +inline at::Tensor angle(const at::Tensor & self) { + return at::_ops::angle::call(self); +} + +// aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & angle_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::angle_out::call(self, out); +} +// aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & angle_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::angle_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d9d0733d232f8f5133f81ba308d3b8b5cf4b541e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor angle(const at::Tensor & self); +TORCH_API at::Tensor & angle_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & angle_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d68a4fe12b95568df35de2d3d2189752f1126825 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
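aten::angle, declared above, returns the element-wise complex phase (and pi for negative real inputs). A minimal sketch assuming libtorch, building a complex tensor from real and imaginary parts via at::complex; values are illustrative.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto re = torch::tensor({1.0, 0.0, -1.0});
      auto im = torch::tensor({0.0, 1.0,  0.0});
      auto z  = at::complex(re, im);

      auto phase = at::angle(z);         // roughly {0, pi/2, pi}

      auto out = torch::empty_like(phase);
      at::angle_out(out, z);             // out= variant declared above

      std::cout << phase << "\n";
      return 0;
    }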
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor angle(const at::Tensor & self); +TORCH_API at::Tensor & angle_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & angle_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_native.h new file mode 100644 index 0000000000000000000000000000000000000000..97de30dad8305009b17d67a045fdeeff8b699e53 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor angle(const at::Tensor & self); +TORCH_API at::Tensor & angle_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor angle_sparse_csr(const at::Tensor & self); +TORCH_API at::Tensor & angle_sparse_csr_out(const at::Tensor & self, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a899d4bce0a928cb149f43e652a66e01e729d228 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/angle_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API angle { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::angle"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "angle(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API angle_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::angle"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "angle.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any.h new file mode 100644 index 0000000000000000000000000000000000000000..b235172795f1a2ef118889cb32aee9325376f454 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any.h @@ -0,0 +1,82 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor +inline at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::any_dim::call(self, dim, keepdim); +} + +// aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::any_dims::call(self, dim, keepdim); +} + +// aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::any_out::call(self, dim, keepdim, out); +} +// aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & any_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_out::call(self, dim, keepdim, out); +} + +// aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::any_dims_out::call(self, dim, keepdim, out); +} +// aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & any_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_dims_out::call(self, dim, keepdim, out); +} + +// aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor +inline at::Tensor any(const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::any_dimname::call(self, dim, keepdim); +} + +// aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::any_dimname_out::call(self, dim, keepdim, out); +} +// aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & any_outf(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_dimname_out::call(self, dim, keepdim, out); +} + +// aten::any(Tensor self) -> Tensor +inline at::Tensor any(const at::Tensor & self) { + return at::_ops::any::call(self); +} + +// aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & any_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::any_all_out::call(self, out); +} +// aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & any_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::any_all_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..a40350edea217a7d18007a8125120e39e946ea6f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeexplicitautograd_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..17bfea006774e09e4d5eb38e5143b6c612c99719 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor any(const at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..709ae8246826b53501f7578ca623237a3925a188 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_compositeimplicitautograd_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor any(const at::Tensor & self, at::Dimname dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e17e71679f57317d0abd1f3961d7bb51c84cd42c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_cpu_dispatch.h @@ -0,0 +1,31 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor any(const at::Tensor & self); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..2fd370fc8f442e4833a410d46a85bc873ca29afe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_cuda_dispatch.h @@ -0,0 +1,31 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor any(const at::Tensor & self); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..e393fd428555ab2bda214e9fbe8572b4d5526b5e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_meta.h @@ -0,0 +1,37 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_any_dim : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, int64_t dim, bool keepdim); +}; +struct TORCH_API structured_any_dims : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); +}; +struct TORCH_API structured_any : public at::impl::MetaBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d70c44c73dba511f1f1316dc60d0683e46980cde --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_meta_dispatch.h @@ -0,0 +1,31 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +TORCH_API at::Tensor any(const at::Tensor & self); +TORCH_API at::Tensor & any_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & any_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_native.h new file mode 100644 index 0000000000000000000000000000000000000000..aa6eab4fd85d4fe4d96d347f5ec27f344d3baf05 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_native.h @@ -0,0 +1,34 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_any_out : public at::meta::structured_any_dim { +void impl(const at::Tensor & self, int64_t dim, bool keepdim, const at::Tensor & out); +}; +TORCH_API at::Tensor any_dims_default(const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & any_dims_out_default(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +struct TORCH_API structured_any_dims_out : public at::meta::structured_any_dims { +void impl(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, const at::Tensor & out); +}; +TORCH_API at::Tensor any(const at::Tensor & self, at::Dimname dim, bool keepdim=false); +TORCH_API at::Tensor & any_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); +struct TORCH_API structured_any_all_out : public at::meta::structured_any { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +TORCH_API at::Tensor any_sparse(const at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/any_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..aa1e970f3c4c43fc4e24f00ac98fdb4cd1468024 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/any_ops.h @@ -0,0 +1,106 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API any_dim { + using schema = at::Tensor (const at::Tensor &, int64_t, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = "dim"; + static constexpr const char* schema_str = "any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, int64_t dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim); +}; + +struct TORCH_API any_dims { + using schema = at::Tensor (const at::Tensor &, at::OptionalIntArrayRef, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = "dims"; + static constexpr const char* schema_str = "any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); +}; + +struct TORCH_API any_out { + using schema = at::Tensor & (const at::Tensor &, int64_t, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); +}; + +struct TORCH_API any_dims_out { + using schema = at::Tensor & (const at::Tensor &, at::OptionalIntArrayRef, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = "dims_out"; + static constexpr const char* schema_str = "any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); +}; + +struct TORCH_API any_dimname { + using schema = at::Tensor (const at::Tensor &, at::Dimname, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = "dimname"; + static constexpr const char* schema_str = "any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::Dimname dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim); +}; + +struct TORCH_API any_dimname_out { + using schema = at::Tensor & (const at::Tensor &, at::Dimname, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = "dimname_out"; + static constexpr const char* schema_str = "any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); +}; + +struct TORCH_API any { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "any(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API any_all_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::any"; + static constexpr const char* overload_name = "all_out"; + static constexpr const char* schema_str = "any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arange.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange.h new file mode 100644 index 0000000000000000000000000000000000000000..5d7b0d2efbe604e88aa066811e1beb5614ea2dd5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange.h @@ -0,0 +1,71 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor arange(const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::arange::call(end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} +// aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor arange(const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange::call(end, dtype, layout, device, pin_memory); +} + +// aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor arange(const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::arange_start::call(start, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} +// aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor arange(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange_start::call(start, end, dtype, layout, device, pin_memory); +} + +// aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor arange(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::TensorOptions options={}) { + return at::_ops::arange_start_step::call(start, end, step, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} +// aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor arange(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange_start_step::call(start, end, step, dtype, layout, device, pin_memory); +} + +// aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arange_out(at::Tensor & out, const at::Scalar & end) { + return at::_ops::arange_out::call(end, out); +} +// aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arange_outf(const at::Scalar & end, at::Tensor & out) { + return at::_ops::arange_out::call(end, out); +} + +// aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arange_out(at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step) { + return at::_ops::arange_start_out::call(start, end, step, out); +} +// aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & arange_outf(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { + return at::_ops::arange_start_out::call(start, end, step, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ec63a7166299bb93c7f1fd95784db613accdbd43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_compositeexplicitautograd_dispatch.h @@ -0,0 +1,30 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor arange(const at::Scalar & end, at::TensorOptions options={}); +TORCH_API at::Tensor arange(const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +TORCH_API at::Tensor & arange_out(at::Tensor & out, const at::Scalar & end); +TORCH_API at::Tensor & arange_outf(const at::Scalar & end, at::Tensor & out); +TORCH_API at::Tensor arange(const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}); +TORCH_API at::Tensor arange(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +TORCH_API at::Tensor arange(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::TensorOptions options={}); +TORCH_API at::Tensor arange(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..56bbe7eed31510dd10b7ca253626c9eecace922a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_cpu_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & arange_out(at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step); +TORCH_API at::Tensor & arange_outf(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5b4f2cabcb7ed594b72d9f21e1cff8174475c7c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_cuda_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & arange_out(at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step); +TORCH_API at::Tensor & arange_outf(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..9c273b36d2300497549eeb2219bc0e36ee89c76c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_meta_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor & arange_out(at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step); +TORCH_API at::Tensor & arange_outf(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_native.h new file mode 100644 index 0000000000000000000000000000000000000000..95d25cc06edf070c6bbc1602116fbc0c7e7402dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arange(const at::Scalar & end, ::std::optional dtype={}, ::std::optional layout={}, ::std::optional device={}, ::std::optional pin_memory={}); +TORCH_API at::Tensor & arange_out(const at::Scalar & end, at::Tensor & out); +TORCH_API at::Tensor arange(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype={}, ::std::optional layout={}, ::std::optional device={}, ::std::optional pin_memory={}); +TORCH_API at::Tensor arange(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step=1, ::std::optional dtype={}, ::std::optional layout={}, ::std::optional device={}, ::std::optional pin_memory={}); +TORCH_API at::Tensor & arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); +TORCH_API at::Tensor & arange_cuda_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..fb7aecc36e59e8be3b6bc5036f7f57d2682b3b74 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arange_ops.h @@ -0,0 +1,73 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arange { + using schema = at::Tensor (const at::Scalar &, ::std::optional, ::std::optional, ::std::optional, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arange"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor"; + static at::Tensor call(const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +}; + +struct TORCH_API arange_start { + using schema = at::Tensor (const at::Scalar &, const at::Scalar &, ::std::optional, ::std::optional, ::std::optional, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arange"; + static constexpr const char* overload_name = "start"; + static constexpr const char* schema_str = "arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"; + static at::Tensor call(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +}; + +struct TORCH_API arange_start_step { + using schema = at::Tensor (const at::Scalar &, const at::Scalar &, const at::Scalar &, ::std::optional, ::std::optional, ::std::optional, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arange"; + static constexpr const char* overload_name = "start_step"; + static constexpr const char* schema_str = "arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"; + static at::Tensor call(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +}; + +struct TORCH_API arange_out { + using schema = at::Tensor & (const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arange"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Scalar & end, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, at::Tensor & out); +}; + +struct TORCH_API arange_start_out { + using schema = at::Tensor & (const at::Scalar &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arange"; + static constexpr const char* overload_name = "start_out"; + static constexpr const char* schema_str = "arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos.h new file mode 100644 index 0000000000000000000000000000000000000000..5c485072a3f9cc958f42fc3cba5c11346a334a95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arccos(Tensor self) -> Tensor +inline at::Tensor arccos(const at::Tensor & self) { + return at::_ops::arccos::call(self); +} + +// aten::arccos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & arccos_(at::Tensor & self) { + return at::_ops::arccos_::call(self); +} + +// aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arccos_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::arccos_out::call(self, out); +} +// aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arccos_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::arccos_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..cd088bc30661b1ab76b6858d139d523c5809fade --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor arccos(const at::Tensor & self); +TORCH_API at::Tensor & arccos_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & arccos_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arccos_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_native.h new file mode 100644 index 0000000000000000000000000000000000000000..41dca8181fb5ae19205f54b5868da978fbd1b701 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arccos(const at::Tensor & self); +TORCH_API at::Tensor & arccos_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arccos_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..59a2f4d951a10773644ecfc21db080ef2b446c56 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccos_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arccos { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arccos"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arccos(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API arccos_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arccos_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arccos_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API arccos_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arccos"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arccos.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh.h new file mode 100644 index 0000000000000000000000000000000000000000..89e3a6c7e3128c687581e2d4a712147b5e0ceac0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arccosh(Tensor self) -> Tensor +inline at::Tensor arccosh(const at::Tensor & self) { + return at::_ops::arccosh::call(self); +} + +// aten::arccosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & arccosh_(at::Tensor & self) { + return at::_ops::arccosh_::call(self); +} + +// aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arccosh_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::arccosh_out::call(self, out); +} +// aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arccosh_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::arccosh_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d78afde78c203811b5cf8a7773d0c070b93173aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor arccosh(const at::Tensor & self); +TORCH_API at::Tensor & arccosh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & arccosh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arccosh_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_native.h new file mode 100644 index 0000000000000000000000000000000000000000..6119ae7e06effa1af5667303231b04b0ef3c2216 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arccosh(const at::Tensor & self); +TORCH_API at::Tensor & arccosh_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arccosh_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..70a9ff1a2711912d6a5bf92b270f14bbfd8d81d9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arccosh_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arccosh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arccosh"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arccosh(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API arccosh_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arccosh_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arccosh_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API arccosh_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arccosh"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arccosh.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin.h new file mode 100644 index 0000000000000000000000000000000000000000..a07b224c14645eb5df74a6c7122dfea46ffe28d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arcsin(Tensor self) -> Tensor +inline at::Tensor arcsin(const at::Tensor & self) { + return at::_ops::arcsin::call(self); +} + +// aten::arcsin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & arcsin_(at::Tensor & self) { + return at::_ops::arcsin_::call(self); +} + +// aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arcsin_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::arcsin_out::call(self, out); +} +// aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arcsin_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::arcsin_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..b3abe85ec70ec1a572858f904688928489d95e0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor arcsin(const at::Tensor & self); +TORCH_API at::Tensor & arcsin_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & arcsin_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arcsin_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_native.h new file mode 100644 index 0000000000000000000000000000000000000000..35d5d9ff617d5b2ba3e118d3caba242da0ee02d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arcsin(const at::Tensor & self); +TORCH_API at::Tensor & arcsin_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arcsin_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..ab74cb32821e6a71185cfe4d3238fc9bba35a48c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsin_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arcsin { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsin"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arcsin(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API arcsin_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsin_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arcsin_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API arcsin_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsin"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arcsin.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e7ec6e01775cff456065297534e860e2874959 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arcsinh(Tensor self) -> Tensor +inline at::Tensor arcsinh(const at::Tensor & self) { + return at::_ops::arcsinh::call(self); +} + +// aten::arcsinh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & arcsinh_(at::Tensor & self) { + return at::_ops::arcsinh_::call(self); +} + +// aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arcsinh_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::arcsinh_out::call(self, out); +} +// aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arcsinh_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::arcsinh_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..9b0b35e630a7e8f00333c0af8a5ebf4e1e4f314f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor arcsinh(const at::Tensor & self); +TORCH_API at::Tensor & arcsinh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & arcsinh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arcsinh_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_native.h new file mode 100644 index 0000000000000000000000000000000000000000..55fcbc4436b683db88ff96507e6288aa8e5692b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arcsinh(const at::Tensor & self); +TORCH_API at::Tensor & arcsinh_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arcsinh_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..bdb433bf1a310c2712c47bb343750f35a02adfd1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arcsinh_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arcsinh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsinh"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arcsinh(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API arcsinh_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsinh_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arcsinh_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API arcsinh_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsinh"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arcsinh.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan.h new file mode 100644 index 0000000000000000000000000000000000000000..8f326cca8351975f89eb489eaf56ec0ce68a8619 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arctan(Tensor self) -> Tensor +inline at::Tensor arctan(const at::Tensor & self) { + return at::_ops::arctan::call(self); +} + +// aten::arctan_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & arctan_(at::Tensor & self) { + return at::_ops::arctan_::call(self); +} + +// aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arctan_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::arctan_out::call(self, out); +} +// aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arctan_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::arctan_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2.h new file mode 100644 index 0000000000000000000000000000000000000000..b8e98fc0b9defc9b5eb1434e03d88e08779af2ff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arctan2(Tensor self, Tensor other) -> Tensor +inline at::Tensor arctan2(const at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2::call(self, other); +} + +// aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arctan2_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2_out::call(self, other, out); +} +// aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arctan2_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::arctan2_out::call(self, other, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3b297ef1b317a7347203e52dc971ebc9c2904299 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor arctan2(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & arctan2_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & arctan2_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +TORCH_API at::Tensor & arctan2_(at::Tensor & self, const at::Tensor & other); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_native.h new file mode 100644 index 0000000000000000000000000000000000000000..4a8b0ded187485f70a0dc6c55cd49c5177948944 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arctan2(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & arctan2_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +TORCH_API at::Tensor & arctan2_(at::Tensor & self, const at::Tensor & other); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..bf8ad440f649c59f85c453a50c571c4b63fd17ca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan2_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arctan2 { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctan2"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arctan2(Tensor self, Tensor other) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & other); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other); +}; + +struct TORCH_API arctan2_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctan2"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arctan2.out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +}; + +struct TORCH_API arctan2_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctan2_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & other); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..22071e3a29dd054ea00f03f0100970431f9e88ce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor arctan(const at::Tensor & self); +TORCH_API at::Tensor & arctan_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & arctan_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arctan_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_native.h new file mode 100644 index 0000000000000000000000000000000000000000..02b4acf656fe3ffff2f7eba212a785c757a5cb1e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arctan(const at::Tensor & self); +TORCH_API at::Tensor & arctan_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arctan_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..e86780a31838fb18b6ace22a7b6a106a376fcb42 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctan_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arctan { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctan"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arctan(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API arctan_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctan_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arctan_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API arctan_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctan"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh.h new file mode 100644 index 0000000000000000000000000000000000000000..b5ec9e4daeb0bfc9e003b05d20e9ae9954fa8c45 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::arctanh(Tensor self) -> Tensor +inline at::Tensor arctanh(const at::Tensor & self) { + return at::_ops::arctanh::call(self); +} + +// aten::arctanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & arctanh_(at::Tensor & self) { + return at::_ops::arctanh_::call(self); +} + +// aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & arctanh_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::arctanh_out::call(self, out); +} +// aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & arctanh_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::arctanh_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ca63c7602157e0da925995584184fda0e7b5918c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor arctanh(const at::Tensor & self); +TORCH_API at::Tensor & arctanh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & arctanh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arctanh_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_native.h new file mode 100644 index 0000000000000000000000000000000000000000..63de5a377fab2f218460dc9197d28016dcbc7e41 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor arctanh(const at::Tensor & self); +TORCH_API at::Tensor & arctanh_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & arctanh_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..56d52b9eb2338a542f8adfdcdea7b9c6f5d692e5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/arctanh_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
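// --- Usage sketch (illustrative only, not part of the generated header): the
// aten::arctan, aten::arctan2, and aten::arctanh entry points declared above,
// covering the functional, out=, and in-place variants. Assumes the umbrella
// <ATen/ATen.h> header; the helper name and tensor values are arbitrary.
#include <ATen/ATen.h>

void arctan_family_usage_sketch() {
  at::Tensor y = at::rand({4});
  at::Tensor x = at::rand({4});
  at::Tensor a = at::arctan(y);       // element-wise arctangent
  at::Tensor b = at::arctan2(y, x);   // two-argument form: angle of (x, y) with correct quadrant
  at::Tensor c = at::arctanh(y);      // element-wise inverse hyperbolic tangent
  at::Tensor out = at::empty_like(y);
  at::arctan2_out(out, y, x);         // out= variant writes into `out`
  at::arctan_(y);                     // in-place variant mutates `y`
}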
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API arctanh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctanh"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arctanh(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API arctanh_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctanh_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arctanh_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API arctanh_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arctanh"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax.h new file mode 100644 index 0000000000000000000000000000000000000000..72f568df37e85f5f7c62323573f160a5d177a84a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor argmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmax::call(self, dim, keepdim); +} + +// aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & argmax_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmax_out::call(self, dim, keepdim, out); +} +// aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & argmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out) { + return at::_ops::argmax_out::call(self, dim, keepdim, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..79995d4b1885fdb25a80473be81ec4cd992059f0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor argmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8955d0ac6cda366862ea83164292b8b0642e3181 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor argmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmax_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7dd663b59e8e8eeeaf92524978cdd2a6c58ac95a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor argmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmax_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..7ebe17718a3465c49c645856600636b5944eec47 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_argmax : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, ::std::optional dim, bool keepdim); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..892aa20730cb05432d47764aab92abc44710844e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
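// --- Usage sketch (illustrative only, not part of the generated header):
// aten::argmax as declared above. dim is optional (the reduction runs over the
// flattened tensor when it is omitted) and keepdim retains the reduced
// dimension with size 1. Assumes <ATen/ATen.h>; names and shapes are arbitrary.
#include <ATen/ATen.h>

void argmax_usage_sketch() {
  at::Tensor t = at::rand({3, 5});
  at::Tensor flat_idx = at::argmax(t);                               // index into the flattened tensor
  at::Tensor row_idx  = at::argmax(t, /*dim=*/1);                    // per-row argmax, shape {3}
  at::Tensor kept     = at::argmax(t, /*dim=*/1, /*keepdim=*/true);  // shape {3, 1}
  at::Tensor out = at::empty({3}, t.options().dtype(at::kLong));     // indices are int64
  at::argmax_out(out, t, /*dim=*/1);                                 // out= variant
}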
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor argmax(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmax_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmax_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_native.h new file mode 100644 index 0000000000000000000000000000000000000000..279a4a83ed95f7649ad5fbfa439092bd78253717 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_argmax_out : public at::meta::structured_argmax { +void impl(const at::Tensor & self, ::std::optional dim, bool keepdim, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..d6c04589ff97a1c3fd13ccde7daeac9b9f2d27c9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmax_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API argmax { + using schema = at::Tensor (const at::Tensor &, ::std::optional, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argmax"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, ::std::optional dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim); +}; + +struct TORCH_API argmax_out { + using schema = at::Tensor & (const at::Tensor &, ::std::optional, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argmax"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin.h new file mode 100644 index 0000000000000000000000000000000000000000..ba795867053f925a41a1cae51cc4b980a16fb5dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor argmin(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmin::call(self, dim, keepdim); +} + +// aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & argmin_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmin_out::call(self, dim, keepdim, out); +} +// aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & argmin_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out) { + return at::_ops::argmin_out::call(self, dim, keepdim, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..72afb788f3aa6b9d863a18e6b289bc265ed56794 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor argmin(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d73e3a1427a8c5d75178fbbe8dc120585577f54d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor argmin(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmin_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmin_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f02c7f19962a7ebdb01bff62b372edc1ebbb528b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor argmin(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmin_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmin_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..2e59f8588e306e7a3fdb264b2c725683e968f3d9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_argmin : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, ::std::optional dim, bool keepdim); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0b54741eb88c7cf801c33a5ddbf9bb3ebc5a9f5d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor argmin(const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmin_out(at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false); +TORCH_API at::Tensor & argmin_outf(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_native.h new file mode 100644 index 0000000000000000000000000000000000000000..01c5273b33f2a653fd3470a51fc5363a4503f54c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_argmin_out : public at::meta::structured_argmin { +void impl(const at::Tensor & self, ::std::optional dim, bool keepdim, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..124a1457ac2d7cc1126df762c5d85bdf56c9af02 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argmin_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API argmin { + using schema = at::Tensor (const at::Tensor &, ::std::optional, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argmin"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, ::std::optional dim, bool keepdim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim); +}; + +struct TORCH_API argmin_out { + using schema = at::Tensor & (const at::Tensor &, ::std::optional, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argmin"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort.h new file mode 100644 index 0000000000000000000000000000000000000000..4a7453ab3e7c8317a13b11f34b8d1d19f82cc08b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort.h @@ -0,0 +1,50 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor +inline at::Tensor argsort(const at::Tensor & self, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort::call(self, dim, descending); +} + +// aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor +inline at::Tensor argsort(const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort_stable::call(self, stable, dim, descending); +} + +// aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & argsort_out(at::Tensor & out, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort_stable_out::call(self, stable, dim, descending, out); +} +// aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & argsort_outf(const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out) { + return at::_ops::argsort_stable_out::call(self, stable, dim, descending, out); +} + +// aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor +inline at::Tensor argsort(const at::Tensor & self, at::Dimname dim, bool descending=false) { + return at::_ops::argsort_dimname::call(self, dim, descending); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1ade142554dcd175c0e0e42521be6deeeb2b1618 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_compositeimplicitautograd_dispatch.h @@ -0,0 +1,27 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
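// --- Usage sketch (illustrative only, not part of the generated header): the
// aten::argsort overloads declared above, including the stable overload and its
// out= variant; the Dimname overload is omitted here. Assumes <ATen/ATen.h>;
// the helper name and shapes are arbitrary.
#include <ATen/ATen.h>

void argsort_usage_sketch() {
  at::Tensor t = at::rand({2, 6});
  at::Tensor asc  = at::argsort(t);                                  // default: last dim, ascending
  at::Tensor desc = at::argsort(t, /*dim=*/1, /*descending=*/true);  // per-row, descending
  at::Tensor stab = at::argsort(t, /*stable=*/true, /*dim=*/1);      // stable sort indices
  at::Tensor out = at::empty({2, 6}, t.options().dtype(at::kLong));  // indices are int64
  at::argsort_out(out, t, /*stable=*/true, /*dim=*/1);               // stable_out variant
}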
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor argsort(const at::Tensor & self, int64_t dim=-1, bool descending=false); +TORCH_API at::Tensor argsort(const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false); +TORCH_API at::Tensor & argsort_out(at::Tensor & out, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false); +TORCH_API at::Tensor & argsort_outf(const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out); +TORCH_API at::Tensor argsort(const at::Tensor & self, at::Dimname dim, bool descending=false); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_native.h new file mode 100644 index 0000000000000000000000000000000000000000..5ec778aa21e02cb18aefdf1c73596534108ab593 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor argsort(const at::Tensor & self, int64_t dim=-1, bool descending=false); +TORCH_API at::Tensor argsort(const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false); +TORCH_API at::Tensor & argsort_out(const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out); +TORCH_API at::Tensor argsort(const at::Tensor & self, at::Dimname dim, bool descending=false); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..6d7b4253245f2da50190ffe33fcbe7bafa8529ff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argsort_ops.h @@ -0,0 +1,62 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API argsort { + using schema = at::Tensor (const at::Tensor &, int64_t, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argsort"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, int64_t dim, bool descending); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool descending); +}; + +struct TORCH_API argsort_stable { + using schema = at::Tensor (const at::Tensor &, bool, int64_t, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argsort"; + static constexpr const char* overload_name = "stable"; + static constexpr const char* schema_str = "argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, bool stable, int64_t dim, bool descending); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim, bool descending); +}; + +struct TORCH_API argsort_stable_out { + using schema = at::Tensor & (const at::Tensor &, bool, int64_t, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argsort"; + static constexpr const char* overload_name = "stable_out"; + static constexpr const char* schema_str = "argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out); +}; + +struct TORCH_API argsort_dimname { + using schema = at::Tensor (const at::Tensor &, at::Dimname, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argsort"; + static constexpr const char* overload_name = "dimname"; + static constexpr const char* schema_str = "argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::Dimname dim, bool descending); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere.h new file mode 100644 index 0000000000000000000000000000000000000000..5683ff7aa384646b105032380a081e57a0f95b86 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::argwhere(Tensor self) -> Tensor +inline at::Tensor argwhere(const at::Tensor & self) { + return at::_ops::argwhere::call(self); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7d14760470f97fc6a8dad5997e053b90ac9b28c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
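// --- Usage sketch (illustrative only, not part of the generated header):
// aten::argwhere declared above returns the coordinates of the nonzero
// elements as an N x ndim index tensor. Assumes <ATen/ATen.h>.
#include <ATen/ATen.h>

void argwhere_usage_sketch() {
  at::Tensor m = at::eye(2);         // 2x2 identity
  at::Tensor idx = at::argwhere(m);  // shape {2, 2}: coordinates (0, 0) and (1, 1)
}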
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor argwhere(const at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_native.h new file mode 100644 index 0000000000000000000000000000000000000000..8b6f9d7febc9a08ef00e41d26fa01ff283c5de89 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor argwhere(const at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..ed86aa6a405814f6d13a00e23431489e5489bf02 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/argwhere_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API argwhere { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::argwhere"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "argwhere(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided.h new file mode 100644 index 0000000000000000000000000000000000000000..0019fbf6d6e79ea47783b04be6aee184ec2ae197 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided.h @@ -0,0 +1,70 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) +inline at::Tensor as_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} +namespace symint { + template >> + at::Tensor as_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? 
::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } +} + +// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) +inline at::Tensor as_strided_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::call(self, size, stride, storage_offset); +} +namespace symint { + template >> + at::Tensor as_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::call(self, size, stride, storage_offset); + } +} + +// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) +inline const at::Tensor & as_strided_(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} +namespace symint { + template >> + const at::Tensor & as_strided_(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } +} + +// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) +inline const at::Tensor & as_strided__symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::call(self, size, stride, storage_offset); +} +namespace symint { + template >> + const at::Tensor & as_strided_(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::call(self, size, stride, storage_offset); + } +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..831bb05a7bc86bdc7920b70e437b592fbc12594d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
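// --- Usage sketch (illustrative only, not part of the generated header):
// aten::as_strided declared above creates a view over the same storage with
// explicit sizes, strides, and an optional storage offset; the *_symint
// overloads take symbolic integers instead. Assumes <ATen/ATen.h>; the strides
// below are chosen for a contiguous 1-D buffer of 6 elements.
#include <ATen/ATen.h>

void as_strided_usage_sketch() {
  at::Tensor base = at::arange(6);                         // values 0..5
  at::Tensor v = at::as_strided(base, {2, 2}, {3, 1});     // rows start 3 apart: [[0,1],[3,4]]
  at::Tensor w = at::as_strided(base, {2, 2}, {3, 1}, 1);  // storage_offset=1:   [[1,2],[4,5]]
  // as_strided_ is the in-place variant: it rewrites the view metadata of its argument.
}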
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API const at::Tensor & as_strided_(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API const at::Tensor & as_strided__symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy.h new file mode 100644 index 0000000000000000000000000000000000000000..c96d5658ae1b9e63a38039cddd2c1c84ef57faff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy.h @@ -0,0 +1,92 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor +inline at::Tensor as_strided_copy(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} +namespace symint { + template >> + at::Tensor as_strided_copy(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } +} + +// aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor +inline at::Tensor as_strided_copy_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::call(self, size, stride, storage_offset); +} +namespace symint { + template >> + at::Tensor as_strided_copy(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::call(self, size, stride, storage_offset); + } +} + +// aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & as_strided_copy_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); +} +namespace symint { + template >> + at::Tensor & as_strided_copy_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? 
::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } +} + +// aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & as_strided_copy_outf(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); +} +namespace symint { + template >> + at::Tensor & as_strided_copy_outf(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::call(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } +} + +// aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & as_strided_copy_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::call(self, size, stride, storage_offset, out); +} +namespace symint { + template >> + at::Tensor & as_strided_copy_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::call(self, size, stride, storage_offset, out); + } +} + +// aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & as_strided_copy_symint_outf(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::call(self, size, stride, storage_offset, out); +} +namespace symint { + template >> + at::Tensor & as_strided_copy_outf(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::call(self, size, stride, storage_offset, out); + } +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6934717cc155ac1b62c0e8bc361e3ee5fa4f831a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_compositeexplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & as_strided_copy_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor & as_strided_copy_outf(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); +TORCH_API at::Tensor & as_strided_copy_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor & as_strided_copy_symint_outf(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7fca551de2e4859dd7d230d8178398b9d3e94fe2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor as_strided_copy(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor as_strided_copy_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_native.h new file mode 100644 index 0000000000000000000000000000000000000000..e3be3508e7e5b5e47fc45ca41369e69cfd0a0830 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor & as_strided_copy_out_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); +TORCH_API at::Tensor as_strided_copy_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..2cecc6d419e50e2eb4a8ec4d05ea1b6864268fb3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_copy_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API as_strided_copy { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::as_strided_copy"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor"; + static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); +}; + +struct TORCH_API as_strided_copy_out { + using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::as_strided_copy"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..faba29da865e0a4c196a4f3ecd5bd627162c47d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_cpu_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor as_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor as_strided_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..569e94db0e9737747e44bb4f6ef3208bda3fc6be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_cuda_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor as_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor as_strided_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ae67918120bbb93091e5a3dfdb0db3395abfb2dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_meta_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor as_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor as_strided_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_native.h new file mode 100644 index 0000000000000000000000000000000000000000..5415fee6a143afd4a19abdbea319c38d5d753d5c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor as_strided_tensorimpl(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor as_strided_tensorimpl_meta_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor as_strided_qtensorimpl(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API const at::Tensor & as_strided__symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a9d8055e33846df05ebccb4c10f8ffb12e92da3e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator 
signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API as_strided { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::as_strided"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"; + static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); +}; + +struct TORCH_API as_strided_ { + using schema = const at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::as_strided_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)"; + static const at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); + static const at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..b6618697b065ece0ad793ae8e522a708f3395a52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter.h @@ -0,0 +1,92 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor +inline at::Tensor as_strided_scatter(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::call(self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} +namespace symint { + template >> + at::Tensor as_strided_scatter(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::call(self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? 
::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } +} + +// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor +inline at::Tensor as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::call(self, src, size, stride, storage_offset); +} +namespace symint { + template >> + at::Tensor as_strided_scatter(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::call(self, src, size, stride, storage_offset); + } +} + +// aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & as_strided_scatter_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::call(self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); +} +namespace symint { + template >> + at::Tensor & as_strided_scatter_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::call(self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } +} + +// aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & as_strided_scatter_outf(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::call(self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); +} +namespace symint { + template >> + at::Tensor & as_strided_scatter_outf(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::call(self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } +} + +// aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & as_strided_scatter_symint_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::call(self, src, size, stride, storage_offset, out); +} +namespace symint { + template >> + at::Tensor & as_strided_scatter_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::call(self, src, size, stride, storage_offset, out); + } +} + +// aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & as_strided_scatter_symint_outf(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::call(self, src, size, stride, storage_offset, out); +} +namespace symint { + template >> + at::Tensor & as_strided_scatter_outf(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::call(self, src, size, stride, storage_offset, out); + } +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4c78a24565f17a6606aae0af6d4903629dd5b284 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_compositeexplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & as_strided_scatter_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor & as_strided_scatter_outf(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); +TORCH_API at::Tensor & as_strided_scatter_symint_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor & as_strided_scatter_symint_outf(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..547fbbdf356193c3148f515d2a16c021ef0fb597 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor as_strided_scatter(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +TORCH_API at::Tensor as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_native.h new file mode 100644 index 0000000000000000000000000000000000000000..9a96bd016e2408d6988b0f066029280b3bf0c872 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor & as_strided_scatter_out_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); +TORCH_API at::Tensor as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..dd4e8022dbe6275008748df21ff1a0b12b4ddc36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/as_strided_scatter_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API as_strided_scatter { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::as_strided_scatter"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); +}; + +struct TORCH_API as_strided_scatter_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, c10::SymIntArrayRef, c10::SymIntArrayRef, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::as_strided_scatter"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin.h new file mode 100644 index 0000000000000000000000000000000000000000..48afa3e11d73dee45fb5101fd63191dbeccd3e43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::asin(Tensor self) -> Tensor +inline at::Tensor asin(const at::Tensor & self) { + return at::_ops::asin::call(self); +} + +// aten::asin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & asin_(at::Tensor & self) { + return at::_ops::asin_::call(self); +} + +// aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & asin_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::asin_out::call(self, out); +} +// aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & asin_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::asin_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..96b4a954f464512874cd96c0e6cb23607843b5f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor asin(const at::Tensor & self); +TORCH_API at::Tensor & asin_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..31b0feb89154a5edbfffd040f35f7b2c03bb29f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor asin(const at::Tensor & self); +TORCH_API at::Tensor & asin_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & asin_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asin_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..efa183cdf24aaf1116ee83056871ddf1978af4e9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor asin(const at::Tensor & self); +TORCH_API at::Tensor & asin_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & asin_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asin_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..c810d3bde66c373926c42b9c5f92842d14110e4b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_asin : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0fc6c4bdd86ce4a0c6de3fb0ffb6a0b906ffb12d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor asin(const at::Tensor & self); +TORCH_API at::Tensor & asin_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & asin_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asin_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_native.h new file mode 100644 index 0000000000000000000000000000000000000000..2c6b2dabca2a1ac18faf8830acc25bc716124ad4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_native.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_asin_out : public at::meta::structured_asin { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +TORCH_API at::Tensor asin_sparse(const at::Tensor & self); +TORCH_API at::Tensor & asin_sparse_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asin_sparse_(at::Tensor & self); +TORCH_API at::Tensor asin_sparse_csr(const at::Tensor & self); +TORCH_API at::Tensor & asin_sparse_csr_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asin_sparse_csr_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..193c775c3612685634ac7d765e3a313c92bd0ee4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asin_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API asin { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::asin"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "asin(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API asin_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::asin_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "asin_(Tensor(a!) 
self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API asin_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::asin"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh.h new file mode 100644 index 0000000000000000000000000000000000000000..2827a34d0901fbc36e7b949d26a3b010bf16ba75 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::asinh(Tensor self) -> Tensor +inline at::Tensor asinh(const at::Tensor & self) { + return at::_ops::asinh::call(self); +} + +// aten::asinh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & asinh_(at::Tensor & self) { + return at::_ops::asinh_::call(self); +} + +// aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & asinh_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::asinh_out::call(self, out); +} +// aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & asinh_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::asinh_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..971b6ff219d558f2398aed7c5e56644269d2f6ca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor asinh(const at::Tensor & self); +TORCH_API at::Tensor & asinh_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ecebabf63dbccca44c0b71a66f91600306da78d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor asinh(const at::Tensor & self); +TORCH_API at::Tensor & asinh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & asinh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asinh_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ce3169f5fe743cb2a7a077232a84411be3f17977 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor asinh(const at::Tensor & self); +TORCH_API at::Tensor & asinh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & asinh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asinh_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..17434e8c414c12e1627af23c1d930e8c49d506c8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_asinh : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7f075051091e25a479b62414a550890cf0a46a32 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor asinh(const at::Tensor & self); +TORCH_API at::Tensor & asinh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & asinh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asinh_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_native.h new file mode 100644 index 0000000000000000000000000000000000000000..67b72f6ee440a6fd75d2528096b2c4036b7bf66c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_native.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_asinh_out : public at::meta::structured_asinh { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +TORCH_API at::Tensor asinh_sparse(const at::Tensor & self); +TORCH_API at::Tensor & asinh_sparse_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asinh_sparse_(at::Tensor & self); +TORCH_API at::Tensor asinh_sparse_csr(const at::Tensor & self); +TORCH_API at::Tensor & asinh_sparse_csr_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asinh_sparse_csr_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..cdf4ab37b40cd4f7e78e6b352537aca88197b1d1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/asinh_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API asinh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::asinh"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "asinh(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API asinh_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::asinh_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "asinh_(Tensor(a!) 
self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API asinh_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::asinh"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan.h new file mode 100644 index 0000000000000000000000000000000000000000..8705ee1ee441176eb39d36b2520701ede1dc86ce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::atan(Tensor self) -> Tensor +inline at::Tensor atan(const at::Tensor & self) { + return at::_ops::atan::call(self); +} + +// aten::atan_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & atan_(at::Tensor & self) { + return at::_ops::atan_::call(self); +} + +// aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & atan_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::atan_out::call(self, out); +} +// aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & atan_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::atan_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2.h new file mode 100644 index 0000000000000000000000000000000000000000..07651f9e9a8165e9ff01187d3a227a755f0dd9ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & atan2_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2_out::call(self, other, out); +} +// aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & atan2_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::atan2_out::call(self, other, out); +} + +// aten::atan2(Tensor self, Tensor other) -> Tensor +inline at::Tensor atan2(const at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2::call(self, other); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..9a93327d661107f53b97fea76dfeb00b6ccb429e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor atan2(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & atan2_(at::Tensor & self, const at::Tensor & other); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e94374e75b13255efcff24337705b51b85d5ba27 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor atan2(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & atan2_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & atan2_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +TORCH_API at::Tensor & atan2_(at::Tensor & self, const at::Tensor & other); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0b63d8101ced79c86f2d0bce03c8a484dd3a1d2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor atan2(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & atan2_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & atan2_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +TORCH_API at::Tensor & atan2_(at::Tensor & self, const at::Tensor & other); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..c536b8579692a5b4d631a4c2008484df3bbfb33b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_atan2 : public TensorIteratorBase { + + + void meta(const at::Tensor & self, const at::Tensor & other); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..31f2cd5d81fcde8d4241d5db21753c86b750d1f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor atan2(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & atan2_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & atan2_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +TORCH_API at::Tensor & atan2_(at::Tensor & self, const at::Tensor & other); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_native.h new file mode 100644 index 0000000000000000000000000000000000000000..9bbe7dc5bf5433f57af0f45588d6a9a4a4696909 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_atan2_out : public at::meta::structured_atan2 { +void impl(const at::Tensor & self, const at::Tensor & other, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..33636cd278c5d2e3b41e76282e13ec38449d3aad --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan2_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API atan2_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atan2"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +}; + +struct TORCH_API atan2_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atan2_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atan2_(Tensor(a!) 
self, Tensor other) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & other); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other); +}; + +struct TORCH_API atan2 { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atan2"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atan2(Tensor self, Tensor other) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & other); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e838afd41bdba19cb412f54f1f06db1c31f76b4b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor atan(const at::Tensor & self); +TORCH_API at::Tensor & atan_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..94917dc8359d41fe41475c735f366c3fc4568650 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor atan(const at::Tensor & self); +TORCH_API at::Tensor & atan_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & atan_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atan_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..fc00f57aca8733065eb7147768bdb8105838b53b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor atan(const at::Tensor & self); +TORCH_API at::Tensor & atan_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & atan_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atan_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..c5328fa0e18820d0b8bc91c1d31dee86ffe5dcee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_atan : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5d36f3abb5a7b7dfe3d23184d89fdcd51165bd26 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor atan(const at::Tensor & self); +TORCH_API at::Tensor & atan_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & atan_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atan_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_native.h new file mode 100644 index 0000000000000000000000000000000000000000..a1050f6247f0a8a1d419285fa50b455f0306e379 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_native.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_atan_out : public at::meta::structured_atan { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +TORCH_API at::Tensor atan_sparse(const at::Tensor & self); +TORCH_API at::Tensor & atan_sparse_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atan_sparse_(at::Tensor & self); +TORCH_API at::Tensor atan_sparse_csr(const at::Tensor & self); +TORCH_API at::Tensor & atan_sparse_csr_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atan_sparse_csr_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..8bcebde76b6152c5615498d2fa8ab0dc586756ef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atan_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API atan { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atan"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atan(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API atan_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atan_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atan_(Tensor(a!) 
self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API atan_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atan"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh.h new file mode 100644 index 0000000000000000000000000000000000000000..b53dc5556e59d6073b8b8020915b73b8eb9759fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh.h @@ -0,0 +1,45 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::atanh(Tensor self) -> Tensor +inline at::Tensor atanh(const at::Tensor & self) { + return at::_ops::atanh::call(self); +} + +// aten::atanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & atanh_(at::Tensor & self) { + return at::_ops::atanh_::call(self); +} + +// aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & atanh_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::atanh_out::call(self, out); +} +// aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & atanh_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::atanh_out::call(self, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6b0949ef9c1f9f3ab034acccb238b43a06181fae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor atanh(const at::Tensor & self); +TORCH_API at::Tensor & atanh_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4150822fac074cba8610284f68468f157561807d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor atanh(const at::Tensor & self); +TORCH_API at::Tensor & atanh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & atanh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atanh_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..464bea2f974fe80590040eff0bb6e68f14f84347 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor atanh(const at::Tensor & self); +TORCH_API at::Tensor & atanh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & atanh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atanh_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..1170ace6668849903b7435440f6eb79fbae7ebd2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_atanh : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d5e8994d40b6a7f33679c382b379f5afadb9c147 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor atanh(const at::Tensor & self); +TORCH_API at::Tensor & atanh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & atanh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atanh_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_native.h new file mode 100644 index 0000000000000000000000000000000000000000..8456c4879d36b4984f3bbc73f29b84ef4ff3e0bb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_native.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_atanh_out : public at::meta::structured_atanh { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +TORCH_API at::Tensor atanh_sparse(const at::Tensor & self); +TORCH_API at::Tensor & atanh_sparse_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atanh_sparse_(at::Tensor & self); +TORCH_API at::Tensor atanh_sparse_csr(const at::Tensor & self); +TORCH_API at::Tensor & atanh_sparse_csr_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & atanh_sparse_csr_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..de4c76380b1e37b367e93df4147cc8f6cb1aa42d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atanh_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API atanh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atanh"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atanh(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API atanh_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atanh_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atanh_(Tensor(a!) 
self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API atanh_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atanh"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d.h new file mode 100644 index 0000000000000000000000000000000000000000..9300cf0361cac47e1569b32a78f40e3f89176b9d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d.h @@ -0,0 +1,36 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::atleast_1d(Tensor self) -> Tensor +inline at::Tensor atleast_1d(const at::Tensor & self) { + return at::_ops::atleast_1d::call(self); +} + +// aten::atleast_1d.Sequence(Tensor[] tensors) -> Tensor[] +inline ::std::vector atleast_1d(at::TensorList tensors) { + return at::_ops::atleast_1d_Sequence::call(tensors); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..300bc7ca3e78a3227087aa2a8aff679317293939 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor atleast_1d(const at::Tensor & self); +TORCH_API ::std::vector atleast_1d(at::TensorList tensors); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..ff85466533257171c8eb1a5160619918fc269c77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor atleast_1d(const at::Tensor & self); +TORCH_API ::std::vector atleast_1d(at::TensorList tensors); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..e35ecce0de2585468ebd659b729357eb58c44c99 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_1d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API atleast_1d { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atleast_1d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atleast_1d(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API atleast_1d_Sequence { + using schema = ::std::vector (at::TensorList); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atleast_1d"; + static constexpr const char* overload_name = "Sequence"; + static constexpr const char* schema_str = "atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]"; + static ::std::vector call(at::TensorList tensors); + static ::std::vector redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d.h new file mode 100644 index 0000000000000000000000000000000000000000..733f97469b924e159538caa8b30a04fca491bb7d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d.h @@ -0,0 +1,36 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::atleast_2d(Tensor self) -> Tensor +inline at::Tensor atleast_2d(const at::Tensor & self) { + return 
at::_ops::atleast_2d::call(self); +} + +// aten::atleast_2d.Sequence(Tensor[] tensors) -> Tensor[] +inline ::std::vector atleast_2d(at::TensorList tensors) { + return at::_ops::atleast_2d_Sequence::call(tensors); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..9a4cf68fe44a7d96f368af6c15d18203e06c58fc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor atleast_2d(const at::Tensor & self); +TORCH_API ::std::vector atleast_2d(at::TensorList tensors); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..961332966aae57b1040cbc831c3b0a4360f7054f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor atleast_2d(const at::Tensor & self); +TORCH_API ::std::vector atleast_2d(at::TensorList tensors); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..ed809d8fa338f85f8d4c0d9f723aba15eae592ec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_2d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API atleast_2d { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atleast_2d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atleast_2d(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API atleast_2d_Sequence { + using schema = ::std::vector (at::TensorList); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atleast_2d"; + static constexpr const char* overload_name = "Sequence"; + static constexpr const char* schema_str = "atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]"; + static ::std::vector call(at::TensorList tensors); + static ::std::vector redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d.h new file mode 100644 index 0000000000000000000000000000000000000000..ce41f0329b92dd9495041071c1c3bccb99832aff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d.h @@ -0,0 +1,36 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::atleast_3d(Tensor self) -> Tensor +inline at::Tensor atleast_3d(const at::Tensor & self) { + return at::_ops::atleast_3d::call(self); +} + +// aten::atleast_3d.Sequence(Tensor[] tensors) -> Tensor[] +inline ::std::vector atleast_3d(at::TensorList tensors) { + return at::_ops::atleast_3d_Sequence::call(tensors); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..db87c8f02a6144e4235040debe8101ede5343518 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor atleast_3d(const at::Tensor & self); +TORCH_API ::std::vector atleast_3d(at::TensorList tensors); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..b17c6d3fb28b7869c44f5a18836a26e6f714f1c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor atleast_3d(const at::Tensor & self); +TORCH_API ::std::vector atleast_3d(at::TensorList tensors); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..bf735549c1b8ec29eafd8007c29e4775600e027d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/atleast_3d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API atleast_3d { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atleast_3d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "atleast_3d(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API atleast_3d_Sequence { + using schema = ::std::vector (at::TensorList); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::atleast_3d"; + static constexpr const char* overload_name = "Sequence"; + static constexpr const char* schema_str = "atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]"; + static ::std::vector call(at::TensorList tensors); + static ::std::vector redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d.h new file mode 100644 index 0000000000000000000000000000000000000000..89fd312b65e7dce250f70745272e8d58b68bd5e7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool 
count_include_pad=True) -> Tensor +inline at::Tensor avg_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true) { + return at::_ops::avg_pool1d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad); +} + +// aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & avg_pool1d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true) { + return at::_ops::avg_pool1d_out::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, out); +} +// aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & avg_pool1d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out) { + return at::_ops::avg_pool1d_out::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8480aac0f0d4f6c4903c1b5557874ad7809b0ea5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & avg_pool1d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true); +TORCH_API at::Tensor & avg_pool1d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3e59782279a6cae0aba7c39e19b1b848c73b3a3b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor avg_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..336dcc9a4c35e38203a8d8a5c053f50da908bc3d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor avg_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true); +TORCH_API at::Tensor & avg_pool1d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..d5bb9fd4b0615f83a775eb112662c6c8770b16f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool1d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API avg_pool1d { + using schema = at::Tensor (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool1d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad); +}; + +struct TORCH_API avg_pool1d_out { + using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool1d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d.h new file mode 100644 index 0000000000000000000000000000000000000000..45c3ca03c2d9bf91ddc91089bd424b3986d0f4c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool2d_out::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); +} +// aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out) { + return at::_ops::avg_pool2d_out::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); +} + +// aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor +inline at::Tensor avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool2d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..edf059e9ded32400f384c02b214823e85ced7c14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & avg_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool2d_backward_grad_input::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); +} +// aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & avg_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input) { + return at::_ops::avg_pool2d_backward_grad_input::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); +} + +// aten::avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? 
divisor_override) -> Tensor +inline at::Tensor avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool2d_backward::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ffbe8359b4116d9b5a47509bd23bb67ebb3bf637 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3f7ed1f32a773957b2bfa94610d5d714ab56f160 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..64c32acfeceebf0e87d7de8989b3c46ef99749d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..8a88d49465ff8f076a6a656962450b128eb78214 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_avg_pool2d_backward : public at::impl::MetaBase { + + + void meta(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..23d66ee62e866388a78e69f1b88929c497297354 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..9f7618b2d5aa2ce30e2005372b9157c7da23540b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_native.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_avg_pool2d_backward_out_cpu : public at::meta::structured_avg_pool2d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & grad_input); +}; +struct TORCH_API structured_avg_pool2d_backward_out_cuda : public at::meta::structured_avg_pool2d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & grad_input); +}; +TORCH_API at::Tensor mkldnn_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & mkldnn_avg_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..14f0245bbaea20033094f6196a6af3c885849bae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_backward_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
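+// How the Operator.h structs below are used (sketch, not part of the generated file):
+// each overload of the operator gets a struct whose static `call` enters the dispatcher
+// from the top and whose `redispatch` continues dispatch under an explicit
+// c10::DispatchKeySet. The inline wrappers in ATen/ops/avg_pool2d_backward.h forward
+// to these, e.g. at::avg_pool2d_backward(...) is roughly:
+//
+//   at::Tensor gi = at::_ops::avg_pool2d_backward::call(
+//       grad_out, input, {2, 2}, {2, 2}, {0, 0}, false, true, ::std::nullopt);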
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API avg_pool2d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool2d_backward"; + static constexpr const char* overload_name = "grad_input"; + static constexpr const char* schema_str = "avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); +}; + +struct TORCH_API avg_pool2d_backward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool2d_backward"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor"; + static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..86eab46c0383b8eaa52c31a8b06762b71310cace --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1b90d41fd258fdd81298ab6dd1874568f3fbfb92 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..700d767788a0ec92ab93433f1a451910e9ae72cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
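+// The _out / _outf pair declared in these dispatch headers differ only in argument
+// order and defaults (illustration; `x` and `out` are hypothetical tensors):
+//
+//   at::cuda::avg_pool2d_out(out, x, /*kernel_size=*/{3, 3});  // out first, trailing args defaulted
+//   at::cuda::avg_pool2d_outf(x, {3, 3}, /*stride=*/{}, /*padding=*/{0, 0},
+//                             /*ceil_mode=*/false, /*count_include_pad=*/true,
+//                             /*divisor_override=*/::std::nullopt, out);  // schema order, out last
+//
+// Both write the pooled result into `out`; an empty stride means "same as kernel_size".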
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..1d2706ea84348f2726eae68277043865b9f3c4f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_meta.h @@ -0,0 +1,114 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_avg_pool2d : public at::impl::MetaBase { + + template + struct TORCH_API precompute_out { + + precompute_out set_kH(int64_t value) { + static_assert(KH == false, "kH already set"); + precompute_out ret; +ret.kH = value; +ret.kW = this->kW; +ret.dH = this->dH; +ret.dW = this->dW; +ret.padH = this->padH; +ret.padW = this->padW; +return ret; + } + + + precompute_out set_kW(int64_t value) { + static_assert(KW == false, "kW already set"); + precompute_out ret; +ret.kH = this->kH; +ret.kW = value; +ret.dH = this->dH; +ret.dW = this->dW; +ret.padH = this->padH; +ret.padW = this->padW; +return ret; + } + + + precompute_out set_dH(int64_t value) { + static_assert(DH == false, "dH already set"); + precompute_out ret; +ret.kH = this->kH; +ret.kW = this->kW; +ret.dH = value; +ret.dW = this->dW; +ret.padH = this->padH; +ret.padW = this->padW; +return ret; + } + + + precompute_out set_dW(int64_t value) { + static_assert(DW == false, "dW already set"); + precompute_out ret; +ret.kH = this->kH; +ret.kW = this->kW; +ret.dH = this->dH; +ret.dW = value; +ret.padH = this->padH; +ret.padW = this->padW; +return ret; + } + + + precompute_out set_padH(int64_t value) { + static_assert(PADH == false, "padH already set"); + precompute_out ret; +ret.kH = this->kH; +ret.kW = this->kW; +ret.dH = this->dH; +ret.dW = this->dW; +ret.padH = value; +ret.padW = this->padW; +return ret; + } + + + precompute_out set_padW(int64_t value) { + static_assert(PADW == false, "padW already set"); + precompute_out ret; +ret.kH = this->kH; +ret.kW = this->kW; +ret.dH = this->dH; +ret.dW = this->dW; +ret.padH = this->padH; +ret.padW = value; +return ret; + } + + int64_t kH; +int64_t kW; +int64_t dH; +int64_t dW; +int64_t padH; +int64_t padW; + }; + using meta_return_ty = precompute_out ; + meta_return_ty meta(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +} // namespace native +} // namespace at diff --git 
a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..56aaee95755b5ece49a1b6f4ef8aef7266ef1c17 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..58557ca0502dc2b81a7fc3a8869aa6c33981fd2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_native.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_avg_pool2d_out_cpu : public at::meta::structured_avg_pool2d { +void impl(const at::Tensor & self, int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & out); +}; +struct TORCH_API structured_avg_pool2d_out_cuda : public at::meta::structured_avg_pool2d { +void impl(const at::Tensor & self, int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & out); +}; +TORCH_API at::Tensor mkldnn_avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & mkldnn_avg_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); +TORCH_API at::Tensor avg_pool2d_quantized_cpu(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, 
at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..95a7929f2cbf0d2beb6dd2d603991c0f8d2d050c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool2d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API avg_pool2d_out { + using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool2d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); +}; + +struct TORCH_API avg_pool2d { + using schema = at::Tensor (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool2d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? 
divisor_override=None) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d.h new file mode 100644 index 0000000000000000000000000000000000000000..006a45dd4a819215fc231db860d2d1d9afe67b2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool3d_out::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); +} +// aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out) { + return at::_ops::avg_pool3d_out::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); +} + +// aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? 
divisor_override=None) -> Tensor +inline at::Tensor avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool3d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..ff949a34b157018a887b41aeaa673723bf042168 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool3d_backward_grad_input::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); +} +// aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input) { + return at::_ops::avg_pool3d_backward_grad_input::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); +} + +// aten::avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? 
divisor_override) -> Tensor +inline at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool3d_backward::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c86f8ccc003774cf51090dea58fab0180c46459e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..86920829dfd84a2f7d13fc1389d7e122e72c88f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ac77a9753053f3144d23f528674848b2fb3bc72a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..be20a70749a2f704a34a0776446ab855b2091441 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_avg_pool3d_backward : public at::impl::MetaBase { + + + void meta(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c09a58d74a16bd73ebf94293521d7498fb7b6027 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
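+// Structured-kernel split used by this operator (rough sketch; the exact generated
+// registration code varies by PyTorch version): the meta struct declared in
+// avg_pool3d_backward_meta.h checks arguments and sizes the output, and the backend
+// structs in avg_pool3d_backward_native.h fill it in. The generated RegisterCPU.cpp
+// does approximately:
+//
+//   at::native::structured_avg_pool3d_backward_out_cpu op;  // a wrapper subclass in practice
+//   op.meta(grad_output, self, kernel_size, stride, padding,
+//           ceil_mode, count_include_pad, divisor_override);   // allocates/sizes grad_input
+//   op.impl(grad_output, self, kernel_size, stride, padding,
+//           ceil_mode, count_include_pad, divisor_override, grad_input);  // computes values
+//
+// The functional and out= variants of the op are both generated from this meta/impl pair.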
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..f926b54c9f38bca2b6bc66557c5dc616634950c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_native.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_avg_pool3d_backward_out_cpu : public at::meta::structured_avg_pool3d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & grad_input); +}; +struct TORCH_API structured_avg_pool3d_backward_out_cuda : public at::meta::structured_avg_pool3d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & grad_input); +}; +TORCH_API at::Tensor mkldnn_avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +TORCH_API at::Tensor & mkldnn_avg_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..11ccd633fa4c777908118b6996a0b8f66dc1baaa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_backward_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API avg_pool3d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool3d_backward"; + static constexpr const char* overload_name = "grad_input"; + static constexpr const char* schema_str = "avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); +}; + +struct TORCH_API avg_pool3d_backward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool3d_backward"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor"; + static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5f467a6bb9bb8fbdd4b32b7dd82462c540137257 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d3869b790d02d9d91d568fb75a0ae681ff6cfb3c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f8f6aa7aed5041fb84802ca407f7628290c0e162 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..3ad817dcb1ae6dc72b7170edc4a73c4bea895cdd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_avg_pool3d : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f200e5b5e5c6e1d1dd6f9d9b2d29ec736c9c8a14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..6ed3ec26e17f9e0273fa18666c26980e35ae4cbd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_native.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_avg_pool3d_out_cpu : public at::meta::structured_avg_pool3d { +void impl(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & out); +}; +struct TORCH_API structured_avg_pool3d_out_cuda : public at::meta::structured_avg_pool3d { +void impl(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, const at::Tensor & out); +}; +TORCH_API at::Tensor mkldnn_avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +TORCH_API at::Tensor & mkldnn_avg_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); +TORCH_API at::Tensor avg_pool3d_quantized_cpu(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..d953af84ca89d6edaa2410761a55ff673116e57b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/avg_pool3d_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API avg_pool3d_out { + using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool3d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); +}; + +struct TORCH_API avg_pool3d { + using schema = at::Tensor (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::avg_pool3d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor"; + static at::Tensor call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm.h new file mode 100644 index 0000000000000000000000000000000000000000..603fdf773986316fb868e24ccde72b3245020bcc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm.h @@ -0,0 +1,54 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm::call(self, batch1, batch2, beta, alpha); +} + +// aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & baddbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_out::call(self, batch1, batch2, beta, alpha, out); +} +// aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & baddbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::baddbmm_out::call(self, batch1, batch2, beta, alpha, out); +} + +// aten::baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_dtype::call(self, batch1, batch2, out_dtype, beta, alpha); +} + +// aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & baddbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_dtype_out::call(self, batch1, batch2, out_dtype, beta, alpha, out); +} +// aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & baddbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::baddbmm_dtype_out::call(self, batch1, batch2, out_dtype, beta, alpha, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6f58ac039e32af37c97cb1c77e8328ff602b8c58 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
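+// Semantics of baddbmm, as declared in these headers (illustration; shapes follow the
+// usual batched-matmul contract): for batch1 of shape (b, n, k) and batch2 of shape
+// (b, k, m), with self broadcastable to (b, n, m),
+//
+//   result[i] = beta * self[i] + alpha * (batch1[i] @ batch2[i])
+//
+//   at::Tensor r = at::baddbmm(self, batch1, batch2, /*beta=*/1, /*alpha=*/1);
+//
+// The .dtype overloads in baddbmm.h above additionally take an at::ScalarType
+// out_dtype that selects the dtype of the returned tensor.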
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d00b98433ee71cbd61f4a71eb66545623efbf6f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & baddbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4ab31d21b3ee1e340bd2de2e97f80ae7937306bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_cuda_dispatch.h @@ -0,0 +1,29 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & baddbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_meta.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..75cf4ae06e418d1ffef3337b899bb3726c598f06 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_baddbmm : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_meta_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..29126cebc9226d0d6665c78ec9cfd8ba1a53458f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & baddbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & baddbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_native.h new file mode 100644 index 0000000000000000000000000000000000000000..a35875cb35a7376525f3e9682b937b9bd7243731 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_native.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_baddbmm_out_cpu : public at::meta::structured_baddbmm { +void impl(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, const at::Tensor & out); +}; +struct TORCH_API structured_baddbmm_out_cuda : public at::meta::structured_baddbmm { +void impl(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, const at::Tensor & out); +}; +TORCH_API at::Tensor & baddbmm_out_sparse_csr_cuda(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor _baddbmm_dtype_cuda(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & _baddbmm_out_dtype_cuda(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..746595b1626441cd468bd4a6fda080498bbe7028 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/baddbmm_ops.h @@ -0,0 +1,73 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API baddbmm { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::baddbmm"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API baddbmm_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::baddbmm_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API baddbmm_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::baddbmm"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +struct TORCH_API baddbmm_dtype { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::ScalarType, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::baddbmm"; + static constexpr const char* overload_name = "dtype"; + static constexpr const char* schema_str = "baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API baddbmm_dtype_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::ScalarType, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::baddbmm"; + static constexpr const char* overload_name = "dtype_out"; + static constexpr const char* schema_str = "baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window.h new file mode 100644 index 0000000000000000000000000000000000000000..e58e0ef33aa724f60cf719299ab66c364eb0f33c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window.h @@ -0,0 +1,62 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor bartlett_window(int64_t window_length, at::TensorOptions options={}) { + return at::_ops::bartlett_window::call(window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} +// aten::bartlett_window(int window_length, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor bartlett_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::bartlett_window::call(window_length, dtype, layout, device, pin_memory); +} + +// aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor bartlett_window(int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::bartlett_window_periodic::call(window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} +// aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor bartlett_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::bartlett_window_periodic::call(window_length, periodic, dtype, layout, device, pin_memory); +} + +// aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bartlett_window_out(at::Tensor & out, int64_t window_length) { + return at::_ops::bartlett_window_out::call(window_length, out); +} +// aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bartlett_window_outf(int64_t window_length, at::Tensor & out) { + return at::_ops::bartlett_window_out::call(window_length, out); +} + +// aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bartlett_window_out(at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::bartlett_window_periodic_out::call(window_length, periodic, out); +} +// aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bartlett_window_outf(int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::bartlett_window_periodic_out::call(window_length, periodic, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..cf452e69c0b43e3009a8ba5a9e24bbf4d4cabe2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_compositeexplicitautograd_dispatch.h @@ -0,0 +1,30 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor bartlett_window(int64_t window_length, at::TensorOptions options={}); +TORCH_API at::Tensor bartlett_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +TORCH_API at::Tensor & bartlett_window_out(at::Tensor & out, int64_t window_length); +TORCH_API at::Tensor & bartlett_window_outf(int64_t window_length, at::Tensor & out); +TORCH_API at::Tensor bartlett_window(int64_t window_length, bool periodic, at::TensorOptions options={}); +TORCH_API at::Tensor bartlett_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +TORCH_API at::Tensor & bartlett_window_out(at::Tensor & out, int64_t window_length, bool periodic); +TORCH_API at::Tensor & bartlett_window_outf(int64_t window_length, bool periodic, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_native.h new file mode 100644 index 0000000000000000000000000000000000000000..52e626ef3a2d33e438ee64f34448d28bf455833a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor bartlett_window(int64_t window_length, ::std::optional dtype={}, ::std::optional layout={}, ::std::optional device={}, ::std::optional pin_memory={}); +TORCH_API at::Tensor & bartlett_window_out(int64_t window_length, at::Tensor & out); +TORCH_API at::Tensor bartlett_window(int64_t window_length, bool periodic, ::std::optional dtype={}, ::std::optional layout={}, ::std::optional device={}, ::std::optional pin_memory={}); +TORCH_API at::Tensor & bartlett_window_periodic_out(int64_t window_length, bool periodic, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..50dc39dd41d0e8cd8f094ff4ee17f140bd216c49 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/bartlett_window_ops.h @@ -0,0 +1,62 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API bartlett_window { + using schema = at::Tensor (int64_t, ::std::optional, ::std::optional, ::std::optional, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::bartlett_window"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor"; + static at::Tensor call(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +}; + +struct TORCH_API bartlett_window_periodic { + using schema = at::Tensor (int64_t, bool, ::std::optional, ::std::optional, ::std::optional, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::bartlett_window"; + static constexpr const char* overload_name = "periodic"; + static constexpr const char* schema_str = "bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"; + static at::Tensor call(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); +}; + +struct TORCH_API bartlett_window_out { + using schema = at::Tensor & (int64_t, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::bartlett_window"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(int64_t window_length, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out); +}; + +struct TORCH_API bartlett_window_periodic_out { + using schema = at::Tensor & (int64_t, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::bartlett_window"; + static constexpr const char* overload_name = "periodic_out"; + static constexpr const char* schema_str = "bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(int64_t window_length, bool periodic, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..07c7f708e670beddae88717bcbb78c4b8e743a86 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor +inline at::Tensor batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled) { + return at::_ops::batch_norm::call(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..0f91732d503766941b30f7ec6facfa4a8fcb8a95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) +inline ::std::tuple batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve) { + return at::_ops::batch_norm_backward::call(grad_out, input, weight, running_mean, running_var, save_mean, save_var, update, eps, output_mask, reserve); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..cf4ecceab7bfe79d6ad555553807ae7561bc222b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_cpu_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::tuple batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6e9d7a008e0fe497062bd88ab2845e59c5792c09 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt.h new file mode 100644 index 0000000000000000000000000000000000000000..7ad8d8c870edcfa8273a462e4e74efde43d43d35 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor +inline at::Tensor batch_norm_backward_elemt(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + return at::_ops::batch_norm_backward_elemt::call(grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); +} + +// aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & batch_norm_backward_elemt_out(at::Tensor & out, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + return at::_ops::batch_norm_backward_elemt_out::call(grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out); +} +// aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & batch_norm_backward_elemt_outf(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out) { + return at::_ops::batch_norm_backward_elemt_out::call(grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4613e2f20f1fe7127f05e5dd6fe9cb7fb2e714b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & batch_norm_backward_elemt_out(at::Tensor & out, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count); +TORCH_API at::Tensor & batch_norm_backward_elemt_outf(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f9d57fb8bacc940227e714ce70fb688ad3fa1010 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor batch_norm_backward_elemt(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_native.h new file mode 100644 index 0000000000000000000000000000000000000000..35c35145b76ef1a8ba1d558c94f73caf56276f1d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor & batch_norm_backward_elemt_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out); +TORCH_API at::Tensor batch_norm_backward_elemt_cuda(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a0d9664eec7b6a5d4189208e9df29d625e39e128 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_elemt_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_backward_elemt { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, const at::Tensor &, const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_backward_elemt"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? 
weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"; + static at::Tensor call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count); +}; + +struct TORCH_API batch_norm_backward_elemt_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, const at::Tensor &, const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_backward_elemt"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..18db308a651e84e5140c8128d5f9ef649f71e775 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple _new_batch_norm_backward_cpu(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); +TORCH_API ::std::tuple _new_batch_norm_backward_cuda(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); +TORCH_API ::std::tuple _new_batch_norm_backward_mkldnn(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, 
bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a76cb40e64c4d02cec09c60ed6395bb6b2387ebb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_backward { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, const ::std::optional &, const ::std::optional &, const ::std::optional &, bool, double, ::std::array, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_backward"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..7593a51409ec55ff68cd4ca87f4c6216c1289352 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? 
weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) +inline ::std::tuple batch_norm_backward_reduce(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + return at::_ops::batch_norm_backward_reduce::call(grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g); +} + +// aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) +inline ::std::tuple batch_norm_backward_reduce_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + return at::_ops::batch_norm_backward_reduce_out::call(grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3); +} +// aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) +inline ::std::tuple batch_norm_backward_reduce_outf(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::batch_norm_backward_reduce_out::call(grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8248fb37ad2eaa5beeafde326c7945b6178835df --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API ::std::tuple batch_norm_backward_reduce_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g); +TORCH_API ::std::tuple batch_norm_backward_reduce_outf(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c6f025f4791179741d87f1acbe66dc5b12a32ba9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple batch_norm_backward_reduce(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_native.h new file mode 100644 index 0000000000000000000000000000000000000000..e78fa9329e0c2ade0ff804262086f2539b62f7c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple batch_norm_backward_reduce_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); +TORCH_API ::std::tuple batch_norm_backward_reduce_cuda(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..f94c4eb9d5ad8eca121ffa0cbfbece15f31d1c97 --- 
/dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_backward_reduce { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, bool, bool, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_backward_reduce"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g); +}; + +struct TORCH_API batch_norm_backward_reduce_out { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, bool, bool, bool, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_backward_reduce"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))"; + static ::std::tuple call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e59141e506b32521e83915f3ac0c51f696d354cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt.h new file mode 100644 index 0000000000000000000000000000000000000000..95a9c39885f6529730085bfd1cff81911fbe6809 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor +inline at::Tensor batch_norm_elemt(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + return at::_ops::batch_norm_elemt::call(input, weight, bias, mean, invstd, eps); +} + +// aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & batch_norm_elemt_out(at::Tensor & out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + return at::_ops::batch_norm_elemt_out::call(input, weight, bias, mean, invstd, eps, out); +} +// aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & batch_norm_elemt_outf(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out) { + return at::_ops::batch_norm_elemt_out::call(input, weight, bias, mean, invstd, eps, out); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e0179a1acc8f1fc2fd290613fc5071f84fee2aa8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor batch_norm_elemt(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps); +TORCH_API at::Tensor & batch_norm_elemt_out(at::Tensor & out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps); +TORCH_API at::Tensor & batch_norm_elemt_outf(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_native.h new file mode 100644 index 0000000000000000000000000000000000000000..c46e853b8047dcfd299f48e886df602c9b04fa26 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor batch_norm_elemt_cuda(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps); +TORCH_API at::Tensor & batch_norm_elemt_cuda_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_ops.h 
b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..e90cfc5d408b413cbb08614390e6214de5d086a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_elemt_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_elemt { + using schema = at::Tensor (const at::Tensor &, const ::std::optional &, const ::std::optional &, const at::Tensor &, const at::Tensor &, double); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_elemt"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor"; + static at::Tensor call(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps); +}; + +struct TORCH_API batch_norm_elemt_out { + using schema = at::Tensor & (const at::Tensor &, const ::std::optional &, const ::std::optional &, const at::Tensor &, const at::Tensor &, double, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_elemt"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats.h new file mode 100644 index 0000000000000000000000000000000000000000..54baea214a795d537b4fd660545af861e764ff51 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? 
running_var, float momentum, float eps, int count) -> (Tensor, Tensor) +inline ::std::tuple batch_norm_gather_stats(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + return at::_ops::batch_norm_gather_stats::call(input, mean, invstd, running_mean, running_var, momentum, eps, count); +} + +// aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_gather_stats_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + return at::_ops::batch_norm_gather_stats_out::call(input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1); +} +// aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_gather_stats_outf(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_gather_stats_out::call(input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..49d28db3270e069768b62c5cd18dd10cb4bb5e06 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API ::std::tuple batch_norm_gather_stats_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count); +TORCH_API ::std::tuple batch_norm_gather_stats_outf(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c6696c8767f71461f58a2a50cdc800e0aa7a5c1e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple batch_norm_gather_stats(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_native.h new file mode 100644 index 0000000000000000000000000000000000000000..ce34d2a7f2277b622f6aa67245c6da9e35f1f9db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple batch_norm_gather_stats_out(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1); +TORCH_API ::std::tuple batch_norm_gather_stats_cuda(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..c175b98a90c2426abe34d28a683b40f832f8a188 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_gather_stats { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, const ::std::optional &, double, double, int64_t); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_gather_stats"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count); +}; + +struct TORCH_API batch_norm_gather_stats_out { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, const ::std::optional &, double, double, int64_t, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_gather_stats"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))"; + static ::std::tuple call(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts.h new file mode 100644 index 0000000000000000000000000000000000000000..1cf4fc50858cd8a27a619fcd14dfee157f9dd5b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor) +inline ::std::tuple batch_norm_gather_stats_with_counts(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + return at::_ops::batch_norm_gather_stats_with_counts::call(input, mean, invstd, running_mean, running_var, momentum, eps, counts); +} + +// aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_gather_stats_with_counts_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + return at::_ops::batch_norm_gather_stats_with_counts_out::call(input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1); +} +// aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_gather_stats_with_counts_outf(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_gather_stats_with_counts_out::call(input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7b26be53a15c86cf9e1f5cb8c41e4ec81b8d6c08 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API ::std::tuple batch_norm_gather_stats_with_counts_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts); +TORCH_API ::std::tuple batch_norm_gather_stats_with_counts_outf(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f4c4be72d4347ddd72ecdb03acccb2f45a06e8c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple batch_norm_gather_stats_with_counts(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_native.h new file mode 100644 index 0000000000000000000000000000000000000000..84f418a58b71e4fbb28183cf1c20aa9a60d6888d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple batch_norm_gather_stats_with_counts_out(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1); +TORCH_API ::std::tuple batch_norm_gather_stats_with_counts_cuda(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..eeac0e2b2877c8a6669a298fa899082d927b90a3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_gather_stats_with_counts_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_gather_stats_with_counts { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, const ::std::optional &, double, double, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_gather_stats_with_counts"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? 
running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts); +}; + +struct TORCH_API batch_norm_gather_stats_with_counts_out { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, const ::std::optional &, double, double, const at::Tensor &, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_gather_stats_with_counts"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))"; + static ::std::tuple call(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_native.h new file mode 100644 index 0000000000000000000000000000000000000000..9e97d90f33d1c8dea90b088a2667c679df64e6d1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..73d00a2f55a45ae03d164d06dc3922970db4edd5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm { + using schema = at::Tensor (const at::Tensor &, const ::std::optional &, const ::std::optional &, const ::std::optional &, const ::std::optional &, bool, double, double, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"; + static at::Tensor call(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats.h new file mode 100644 index 0000000000000000000000000000000000000000..003ec12256830822d40e852389cdc3f6935e0d00 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) +inline ::std::tuple batch_norm_stats(const at::Tensor & input, double eps) { + return at::_ops::batch_norm_stats::call(input, eps); +} + +// aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_stats_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double eps) { + return at::_ops::batch_norm_stats_out::call(input, eps, out0, out1); +} +// aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_stats_outf(const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_stats_out::call(input, eps, out0, out1); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e9df01829a56509f6c5fcadb9e57d34a358ab549 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API ::std::tuple batch_norm_stats_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double eps); +TORCH_API ::std::tuple batch_norm_stats_outf(const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8fbc46e94f58e7005ead876ed4851408910663e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple batch_norm_stats(const at::Tensor & input, double eps); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_native.h new file mode 100644 index 0000000000000000000000000000000000000000..e0affe9d7587677f0e297fcdbcb339c1dfe72e39 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple batch_norm_stats_out(const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1); +TORCH_API ::std::tuple batch_norm_stats_cuda(const at::Tensor & input, double eps); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..9a777a8136d43eb02aaa25a47cd74941c0953839 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_stats_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_stats { + using schema = ::std::tuple (const at::Tensor &, double); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_stats"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & input, double eps); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps); +}; + +struct TORCH_API batch_norm_stats_out { + using schema = ::std::tuple (const at::Tensor &, double, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_stats"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))"; + static ::std::tuple call(const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1); +}; + +}} // namespace at::_ops diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats.h new file mode 100644 index 0000000000000000000000000000000000000000..6f8df3ecff5a880580af4162e548c84dbd89ff0f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor) +inline ::std::tuple batch_norm_update_stats(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + return at::_ops::batch_norm_update_stats::call(input, running_mean, running_var, momentum); +} + +// aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_update_stats_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + return at::_ops::batch_norm_update_stats_out::call(input, running_mean, running_var, momentum, out0, out1); +} +// aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple batch_norm_update_stats_outf(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_update_stats_out::call(input, running_mean, running_var, momentum, out0, out1); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_compositeexplicitautograd_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..923e8efe0ce45ea5f822bebbd40c7ebbe6dc9f9c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API ::std::tuple batch_norm_update_stats_out(at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); +TORCH_API ::std::tuple batch_norm_update_stats_outf(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cpu_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..507a3b9af391bff8b5150bb24243e45504590f1c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cpu_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::tuple batch_norm_update_stats(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); + +} // namespace cpu +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e398dc551a1b181b59a10945e1bde30cd034c8f8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple batch_norm_update_stats(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); + +} // namespace cuda +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_native.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_native.h new file mode 100644 index 0000000000000000000000000000000000000000..1dc21be559e195c9cee4bec981aa9967b23cb07f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple batch_norm_update_stats_out(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1); +TORCH_API ::std::tuple batch_norm_update_stats_cpu(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); +TORCH_API ::std::tuple batch_norm_update_stats_cuda(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); +} // namespace native +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_ops.h b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a0d99b2312e48b8eddde5614aa287b8b977a1268 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ops/batch_norm_update_stats_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_update_stats { + using schema = ::std::tuple (const at::Tensor &, const ::std::optional &, const ::std::optional &, double); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_update_stats"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? 
running_var, float momentum) -> (Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); +}; + +struct TORCH_API batch_norm_update_stats_out { + using schema = ::std::tuple (const at::Tensor &, const ::std::optional &, const ::std::optional &, double, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::batch_norm_update_stats"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))"; + static ::std::tuple call(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1); +}; + +}} // namespace at::_ops