Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +7 -0
- phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd +3 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Formatting.h +25 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Generator.h +191 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h +39 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/IListRef.h +631 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h +203 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h +111 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/List.h +491 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/List_inl.h +353 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h +194 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/NamedTensor.h +143 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h +187 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h +240 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h +35 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h +22 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h +84 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Range.h +25 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Reduction.h +14 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Scalar.h +1 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/ScalarType.h +1 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Tensor.h +98 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h +275 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/TensorBase.h +1056 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/TensorBody.h +0 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h +17 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h +175 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h +1 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h +21 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h +83 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Variadic.h +92 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/Vitals.h +94 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h +213 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h +106 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h +283 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h +320 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h +27 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h +38 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h +41 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h +410 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +785 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h +140 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h +67 -0
- phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h +279 -0
.gitattributes
CHANGED
@@ -54,3 +54,10 @@ phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=
 phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
phivenv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-39.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9ae9e3e39533b703fa0fb49576d02a073be55a6fbe3f9d9a38cbeb9ed03e116
+size 100308
phivenv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-39.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88b150085f0eb6dcd1c70d632b16ccf923b66e1700800a4756d06b3726b91fcf
+size 132673
phivenv/Lib/site-packages/pkg_resources/__pycache__/__init__.cpython-39.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92b5449a62f76826fcde2e62b85c16f953a7ccfff8847bc4854a098b5a954dae
+size 100411
phivenv/Lib/site-packages/pkg_resources/_vendor/__pycache__/pyparsing.cpython-39.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f26434c485b7881d3ef563e57c88a171319a39cfcc3bf348cbe5bfd0d2a9887
+size 201319
phivenv/Lib/site-packages/regex/__pycache__/_regex_core.cpython-39.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5727abd2cd4972398036f183a2e811e78ffa31946bf89c453917a171a61c12aa
+size 114484
phivenv/Lib/site-packages/regex/__pycache__/test_regex.cpython-39.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa833453940a5409176fe65a5ba338e66d9d875a4a905f92c064b1ade0faba66
+size 140105
phivenv/Lib/site-packages/regex/_regex.cp39-win_amd64.pyd
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72ee579e80fb57b5b52f1a5a44b4dcbf85567e43442ad80f9da51f21e2f9977f
+size 723968
phivenv/Lib/site-packages/torch/include/ATen/core/Formatting.h
ADDED
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <ostream>
+#include <string>
+
+#include <c10/core/Scalar.h>
+#include <ATen/core/Tensor.h>
+
+namespace c10 {
+TORCH_API std::ostream& operator<<(std::ostream& out, Backend b);
+TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s);
+TORCH_API std::string toString(const Scalar& s);
+}
+namespace at {
+
+TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t);
+TORCH_API std::ostream& print(
+    std::ostream& stream,
+    const Tensor& tensor,
+    int64_t linesize);
+inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
+  return print(out,t,80);
+}
+TORCH_API void print(const Tensor & t, int64_t linesize=80);
+}
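The header above only declares the printing entry points; a minimal usage sketch (an assumption on my part, not part of the diff, and requiring a program already linked against libtorch) would look like this:

#include <ATen/ATen.h>
#include <iostream>

void dump(const at::Tensor& t) {
  // operator<< forwards to print() with the default line size of 80,
  // per the inline overload declared in Formatting.h above.
  std::cout << t << std::endl;
  // The explicit overload lets the caller pick a wider line size.
  at::print(std::cout, t, /*linesize=*/120);
}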
phivenv/Lib/site-packages/torch/include/ATen/core/Generator.h
ADDED
@@ -0,0 +1,191 @@
+#pragma once
+
+#include <cstdint>
+#include <deque>
+#include <mutex>
+#include <utility>
+
+#include <c10/util/Exception.h>
+#include <c10/util/intrusive_ptr.h>
+#include <c10/core/Device.h>
+#include <c10/core/DispatchKeySet.h>
+
+// For the record I don't think this is a correct pimpl idiom.
+// Including Impl header in interface header defeats the purpose
+// because you can't change Impl private members without forcing
+// everything that included the interface to rebuild.
+// Impl should be forward-declared in the interface header instead.
+#include <c10/core/GeneratorImpl.h>
+
+/**
+ * Note [Generator]
+ * ~~~~~~~~~~~~~~~~
+ * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to
+ * generate a seemingly random sequence of numbers, that may be later be used in creating
+ * a random distribution. Such an engine almost always maintains a state and requires a
+ * seed to start off the creation of random numbers. Often times, users have
+ * found it beneficial to be able to explicitly create, retain, and destroy
+ * PRNG states and also be able to have control over the seed value.
+ *
+ * A Generator in ATen gives users the ability to read, write and modify a PRNG engine.
+ * For instance, it does so by letting users seed a PRNG engine, fork the state of the
+ * engine, etc.
+ *
+ * By default, there is one generator per device, and a device's generator is
+ * lazily created. A user can use the torch.Generator() api to create their own generator.
+ */
+
+/**
+ * Note [Acquire lock when using random generators]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Generator and its derived classes are NOT thread-safe. Please note that most of the
+ * places where we have inserted locking for generators are historically based, and we
+ * haven't actually checked that everything is truly thread safe (and it probably isn't).
+ * Please use the public mutex_ when using any methods from these classes, except for the
+ * read-only methods. You can learn about the usage by looking into the unittests
+ * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
+ *
+ * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
+ * them non-thread safe and instead making the generator state splittable, to accommodate
+ * forks into other threads).
+ */
+
+namespace at {
+
+class Tensor;
+
+struct TORCH_API Generator {
+  Generator() = default;
+
+  explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
+      : impl_(std::move(gen_impl)) {
+    if (impl_.get() == nullptr) {
+      throw std::runtime_error("GeneratorImpl with nullptr is not supported");
+    }
+  }
+
+  bool operator==(const Generator& rhs) const {
+    return this->impl_ == rhs.impl_;
+  }
+
+  bool operator!=(const Generator& rhs) const {
+    return !((*this) == rhs);
+  }
+
+  bool defined() const {
+    return static_cast<bool>(impl_);
+  }
+
+  c10::GeneratorImpl* unsafeGetGeneratorImpl() const {
+    return impl_.get();
+  }
+
+  c10::GeneratorImpl* unsafeReleaseGeneratorImpl() {
+    return impl_.release();
+  }
+
+  const c10::intrusive_ptr<c10::GeneratorImpl>& getIntrusivePtr() const {
+    return impl_;
+  }
+
+  void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); }
+  // Sets the offset of Generator state to the desired offset. This is currently
+  // supported for only Philox based Generators, i.e., CUDA and MPS.
+  void set_offset(uint64_t offset) { impl_->set_offset(offset); }
+
+  // Returns the offset of Generator state. This is currently supported for only
+  // Philox based Generators, i.e., CUDA and MPS.
+  uint64_t get_offset() const { return impl_->get_offset(); }
+
+  uint64_t current_seed() const { return impl_->current_seed(); }
+
+  uint64_t seed() { return impl_->seed(); }
+
+  // Implementation not inlined to prevent cycle reference between
+  // `ATen/core/Generator.h` and `ATen/core/Tensor.h`
+  void set_state(const at::Tensor& new_state);
+
+  at::Tensor get_state() const;
+
+  void graphsafe_set_state(const Generator& new_state);
+
+  Generator graphsafe_get_state() const;
+
+  std::mutex& mutex() {
+    return impl_->mutex_;
+  }
+
+  DispatchKeySet key_set() const {
+    return impl_->key_set();
+  }
+
+  Device device() const { return impl_->device(); }
+
+  inline void set_pyobj(PyObject* pyobj) const noexcept {
+    impl_->set_pyobj(pyobj);
+  }
+
+  inline PyObject* pyobj() const noexcept {
+    return impl_->pyobj();
+  }
+
+  template<typename T>
+  T* get() const { return static_cast<T*>(impl_.get()); }
+
+  Generator clone() const {
+    return Generator(impl_->clone());
+  }
+
+ private:
+  c10::intrusive_ptr<c10::GeneratorImpl> impl_;
+};
+
+template<class Impl, class... Args>
+Generator make_generator(Args&&... args) {
+  return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
+}
+
+/**
+ * Utility function to static cast input Generator* to
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
+ */
+template <typename T>
+inline T * check_generator(std::optional<Generator> gen) {
+  TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
+  TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
+  TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
+  return gen->get<T>();
+}
+
+/**
+ * Utility function used in tensor implementations, which
+ * supplies the default generator to tensors, if an input generator
+ * is not supplied. The input Generator* is also static casted to
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
+ */
+template <typename T>
+inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
+  return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
+}
+
+namespace detail {
+
+/**
+ * Helper function for checking the validity of new random generator
+ * state. Right now following conditions are checked:
+ *
+ * - The new state tensor must be a torch.ByteTensor
+ * - Data of the new state tensor must be contiguous
+ */
+inline void check_rng_state(const c10::TensorImpl& new_state) {
+  TORCH_CHECK_TYPE(
+    new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
+    "RNG state must be a torch.ByteTensor"
+  );
+
+  TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
+}
+
+} // namespace detail
+
+} // namespace at
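Following the two notes in the header above (one lazily created generator per device, and callers must hold the generator's public mutex around mutating calls), a minimal sketch of constructing and seeding a generator might look like the following. It assumes `at::CPUGeneratorImpl` from ATen's CPU backend as the implementation class; that header is not part of this diff.

#include <ATen/core/Generator.h>
#include <ATen/CPUGeneratorImpl.h>  // assumed backend implementation header
#include <mutex>

at::Generator make_seeded_cpu_generator(uint64_t seed) {
  // make_generator forwards its arguments to the Impl constructor.
  at::Generator gen = at::make_generator<at::CPUGeneratorImpl>();
  // See Note [Acquire lock when using random generators]: mutating calls
  // such as set_current_seed should be done while holding the mutex.
  std::lock_guard<std::mutex> lock(gen.mutex());
  gen.set_current_seed(seed);
  return gen;
}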
phivenv/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <ATen/core/Generator.h>
+#include <c10/util/intrusive_ptr.h>
+
+namespace at {
+
+using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
+
+TORCH_API std::optional<GeneratorFuncType>& GetGeneratorPrivate();
+
+class TORCH_API _GeneratorRegister {
+ public:
+  explicit _GeneratorRegister(const GeneratorFuncType& func);
+};
+
+TORCH_API at::Generator GetGeneratorForPrivateuse1(
+    c10::DeviceIndex device_index);
+
+/**
+ * This is used to register Generator to PyTorch for `privateuse1` key.
+ *
+ * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
+ *
+ * class CustomGeneratorImpl : public c10::GeneratorImpl {
+ *   CustomGeneratorImpl(DeviceIndex device_index = -1);
+ *   explicit ~CustomGeneratorImpl() override = default;
+ *   ...
+ * };
+ *
+ * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
+ *   return at::make_generator<CustomGeneratorImpl>(id);
+ * }
+ */
+
+#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
+  static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate);
+
+} // namespace at
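The registration flow documented in the header's comment can be sketched end to end as below. This is an outline rather than a compilable unit: `MyDeviceGeneratorImpl` stands in for a hypothetical, fully implemented `c10::GeneratorImpl` subclass that an out-of-tree `privateuse1` backend would provide elsewhere.

#include <ATen/core/GeneratorForPrivateuseone.h>

// Factory that the dispatcher-side helper will call for privateuse1 devices.
at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
  return at::make_generator<MyDeviceGeneratorImpl>(id);
}

// Static registration: after this, GetGeneratorForPrivateuse1(index)
// returns generators produced by the factory above.
REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)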
phivenv/Lib/site-packages/torch/include/ATen/core/IListRef.h
ADDED
@@ -0,0 +1,631 @@
+#pragma once
+
+#include <ATen/core/ivalue_to.h>
+#include <c10/util/ArrayRef.h>
+#include <c10/util/Exception.h>
+
+#include <functional>
+#include <initializer_list>
+#include <iterator>
+#include <type_traits>
+
+/*
+ * [Note: IListRef]
+ * Wrapper around different API containers (e.g. boxed and unboxed).
+ *
+ * What is it?
+ * ===========
+ * It is a tagged union of both boxed and unboxed API containers.
+ * Working implementations:
+ *
+ * - `IListRef<at::Tensor>`
+ * - `IListRef<at::OptionalTensorRef>`
+ *
+ * Note that `IListRef` is a view type. Meaning that it won't own the
+ * tensors it holds. It's intended to be used only as argument parameters.
+ * Specifically, where these 2 worlds overlap.
+ *
+ * What is this for?
+ * =================
+ * Historically, PyTorch has maintained 2 different APIs: the unboxed
+ * (called from C++ API and Python eager mode) and boxed APIs (called
+ * from the TorchScript JIT, mobile interpreter, and boxed fallbacks).
+ *
+ * Calling unboxed kernels from the boxed "world" and vice-versa may
+ * result in non-negligible overhead. Lists are one of those types:
+ *
+ * - Boxed world: `c10::List`
+ * - Unboxed world: `c10::ArrayRef`
+ *
+ * In this context, `c10::IListRef` solves this problem by wrapping those
+ * 2 container types, so that we don't need to convert from one to
+ * the other.
+ *
+ * (see https://github.com/pytorch/pytorch/issues/66328)
+ *
+ * What does it do?
+ * ================
+ * This container wraps around the different tagged containers
+ * (currently, only boxed and unboxed), without incurring in extra
+ * overhead for converting from one to another. It does so while
+ * exposing usual container methods, which dispatch to corresponding
+ * implementations.
+ *
+ * While it works with different container types, it introduces
+ * overhead for repeatedly calling member functions (since those will
+ * get dispatched, again). Therefore, you should only use it to iterate
+ * through the list up to one time. If you need to do more complex things,
+ * call `materialize()` first.
+ *
+ * Adding support for a new Tag
+ * ============================
+ * Suppose we want to add a new tag: `Chest`. Here are the steps
+ * we would have to go through:
+ *
+ * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`.
+ *
+ *   #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
+ *     ...
+ *     _(Chest, ##__VA_ARGS__)
+ *
+ * 2. Add type aliases, union members, and constructors.
+ *
+ *   template <typename T>
+ *   class IListRef {
+ *     ...
+ *     using chest_type =
+ *       typename detail::IListRefTagImpl<T, IListRefTag::Chest>::list_type;
+ *     ...
+ *     IListRef(...) : tag_(IListRefTag::Chest) {
+ *       ...
+ *     }
+ *     ...
+ *     union Payload {
+ *       ...
+ *       chest_type chest;
+ *       ...
+ *     };
+ *     ...
+ *   };
+ *
+ * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's
+ *    preferable to make the default implementation work for `T = Tensor`
+ *    (both `Unboxed` and `Boxed` do it).
+ *
+ *   template <typename T, typename ListElemT>
+ *   class IListRefTagImplBase<IListRefTag::Chest, T, ListElemT> {
+ *    public:
+ *     using elem_type = ListElemT;
+ *     using list_type = ChestContainer<elem_type>;
+ *
+ *     static const list_type& unwrap(const IListRef<T>& ilist) { ... }
+ *
+ *     static typename list_type::const_iterator& unwrap(
+ *         IListRefIterator<T>& it) { ... }
+ *
+ *     static const typename list_type::const_iterator& unwrap(
+ *         const IListRefIterator<T>& it) { ... }
+ *
+ *     static IListRefConstRef<T> iterator_get(
+ *         const typename list_type::const_iterator& it) { ... }
+ *   }
+ *
+ * 4. Add an specialization for each of the already supported types.
+ *    Finally, for consistency, add them to the tracking list.
+ *    (see [Note: IListRefTagImpl Specializations])
+ *
+ *   template <>
+ *   class IListRefTagImpl<IListRefTag::Chest, at::Tensor>
+ *       : public IListRefTagImplBase<IListRefTag::Chest, at::Tensor> {};
+ *
+ * Adding support for a new Type
+ * =============================
+ * Suppose we want to add support for a new type: `Matrix`.
+ * Here are the steps we would have to go through:
+ *
+ * 1. Add an specialization for each of the existing tags.
+ *    For consistency, add them to the tracking list.
+ *    (see [Note: IListRefTagImpl Specializations])
+ *
+ *   template <>
+ *   class IListRefTagImpl<IListRefTag::Unboxed, Matrix>
+ *       : public IListRefTagImplBase<IListRefTag::Unboxed, Matrix> {};
+ *
+ *   template <>
+ *   class IListRefTagImpl<Matrix, IListRefTag::Boxed>
+ *       : public IListRefTagImplBase<IListRefTag::Boxed, Matrix> {};
+ *
+ * Common Problems
+ * ===============
+ * 1. One of `IListRef(Iterator)` methods are failing to compile.
+ *
+ *     That may be happening because the container type you added
+ *     is not compatible with the code written for that method. If
+ *     that's true, then you might have to transform that code into
+ *     a static method call (see `List::operator[]` method).
+ *
+ * 2. Can't make `IListRefIterator<T>::operator*` return a const-reference.
+ *
+ *     First, keep in mind that we assume that boxed containers will
+ *     have to deal with `IValue` (e.g. `c10::List`). In this context,
+ *     what may be happening is that `IValue` doesn't store internally
+ *     your type `T`. Instead, it constructs a type new `T` everytime
+ *     you try to get `T` for it (see `IListRef<at::OptinalTensorRef>`).
+ */
+
+namespace c10 {
+template <typename T>
+class IListRef;
+
+/*
+ * Applies arbitrary macros to each `IListRefTag`.
+ */
+#define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
+  _(Unboxed, ##__VA_ARGS__)                \
+  _(Boxed, ##__VA_ARGS__)                  \
+  _(Materialized, ##__VA_ARGS__)
+
+/*
+ * Defines a "switch-case" for `TAG`. Inside, it executes `BODY`,
+ * while bringing to scope:
+ *
+ * - `ImplT`: the implementation class for `TAG`
+ * - `this_`: the result of unwrapping `this`
+ */
+#define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY)                        \
+  case c10::IListRefTag::TAG: {                                      \
+    using ImplT = c10::detail::IListRefTagImpl<IListRefTag::TAG, T>; \
+    auto& this_ = ImplT::unwrap(*this);                              \
+    BODY                                                             \
+  } break;
+
+/*
+ * Dispatches the unwrap call, depending on `TAG`, followed by
+ * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`.
+ *
+ * This macro is useful because it allows us to handle different
+ * types (that correspond to different tags) to be implemented
+ * only once. We can do it even when the implementation of the
+ * different tags aren't syntatically the same, by dispatching
+ * it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
+ */
+#define TORCH_ILISTREF_UNWRAP(TAG, BODY)                         \
+  switch (TAG) {                                                 \
+    TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
+    break;                                                       \
+    default:                                                     \
+      TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");     \
+  }
+
+enum class IListRefTag {
+#define DEFINE_TAG(tag, ...) tag,
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG)
+#undef DEFINE_TAG
+  None
+};
+
+namespace detail {
+/*
+ * Type alias that specifies whether we return a reference or a copy of `T`.
+ *
+ * What is this for?
+ * =================
+ * Since values in the boxed world are represented by an `IValue`, we also
+ * depend on whether it can be converted to a const-reference (`Tensor`) or
+ * has to create a new copy of `T` (`OptionalTensorRef`).
+ */
+template <typename T>
+using IListRefConstRef = typename ivalue_to_const_ref_overload_return<T>::type;
+
+/*
+ * Interface that implements key functions for each `IListRefTag` type.
+ *
+ * What is this for?
+ * =================
+ * Given an `IListRef(Iterator)<T>`, some methods have to be implemented
+ * differently for each `TAG`. Therefore, the methods inside this class
+ * are used as dispatch targets for the different `IListRefTag` values.
+ *
+ * You should create an specialization of this class for each possible
+ * combination of `IListRefTag` type (except `None`) and element types
+ * (e.g. `Tensor`).
+ *
+ * What does it do?
+ * ================
+ * 1. defines static methods to be used as dispatch targets by both
+ *    `IListRef<T>` and `IListRefIterator<T>` (see the implementation of
+ *    `IListRefTagImplBase`).
+ *
+ * 2. defines the `elem_type` and `list_type` aliases that will be
+ *    used in the definition of `IListRef<T>`. In general, we should do
+ *    so by inheriting from `IListRefTagImplBase<TAG, T, ListElemT>`.
+ *
+ * [Note: IListRefTagImpl Specialization]
+ * ======================================
+ * For `IListRef(Iterator)<at::Tensor>`:
+ * - <IListRefTag::Unboxed, at::Tensor>
+ * - <IListRefTag::Boxed, at::Tensor>
+ * - <IListRefTag::Materialized, at::Tensor>
+ *
+ * For `IListRef(Iterator)<at::OptionalTensorRef>`:
+ * - <IListRefTag::Unboxed, at::OptionalTensorRef>
+ * - <IListRefTag::Boxed, at::OptionalTensorRef>
+ * - <IListRefTag::Materialized, at::OptionalTensorRef>
+ */
+template <IListRefTag TAG, typename T>
+class IListRefTagImpl {};
+
+/*
+ * Base implementation of `IListRefTagImpl<TAG, T>` methods.
+ *
+ * What is this for?
+ * =================
+ * This should make adding specializations for new types easier. For
+ * example, one should be able to add a new type just by making its
+ * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`.
+ *
+ * You should create a partial specialization for this class only if
+ * you introduce a new `IListRefTag`. The idea being that there is one
+ * default implementation for each possible value of `IListRefTag`.
+ *
+ * What does it do?
+ * ================
+ * 1. defines `elem_type` as an alias to `ListElemT`.
+ *
+ * 1. defines `list_type` as an alias to the default container type
+ *    that will hold a collection of `elem_type`. The idea being that
+ *    all types tagged as `TAG` will have `list_type` as its container,
+ *    with different `elem_type`.
+ *
+ * 3. defines the default implementation for each of the methods that
+ *    are supposed to be defined on `IListRefTagImpl` specializations.
+ *
+ * 4. inheriting from `IListRefTagImplBase<TAG, T, ListElemT>` also means
+ *    that the payload of the type `IListRef<T>` will be of type `list_type`
+ *    when it is tagged as `TAG`.
+ */
+template <IListRefTag TAG, typename T, typename ListElemT = T>
+class IListRefTagImplBase {};
+
+/*
+ * Materialized container for `IListRef<T>`.
+ *
+ * What is this for?
+ * =================
+ * Container that groups `T` references together. This exchanges the
+ * overhead of every method call from `IListRef<T>` for a dynamic allocation.
+ *
+ * You should use this container instead of `IListRef<T>` if:
+ *
+ * - You are going to iterate the list more than once
+ * - You need to repeatedly access arbitrary elements (using `operator[]`)
+ * What does it do?
+
+ * ================
+ * Removes the reference (&) from the type, and wraps it into a
+ * `std::reference_wrapper`. If `IListRefConstRef<T>` is not a
+ * reference type, then it's left unchanged.
+ */
+template <typename T>
+using _MaterializedIListRefElem = std::conditional_t<
+    std::is_reference_v<T>,
+    typename std::reference_wrapper<std::remove_reference_t<T>>,
+    T>;
+
+template <typename T>
+using MaterializedIListRefElem = _MaterializedIListRefElem<IListRefConstRef<T>>;
+
+template <typename T>
+using MaterializedIListRef = std::vector<MaterializedIListRefElem<T>>;
+
+} // namespace detail
+
+/*
+ * Iterator for `IListRef<T>`.
+ *
+ * What is it?
+ * ===========
+ * Currently, a `std::bidirectional_iterator` that wraps the iterator
+ * types defined for each of the `IListRefTag`.
+ *
+ * One should be able to use it, as if it were the unwrapped
+ * iterators themselves.
+
+ * What does it do?
+ * ================
+ * Similarly to `IListRef<T>`, this is a wrapper class. Specifically, it
+ * wraps each container's `const_iterator` type alias. So, for example,
+ * given that the container for `IListRefTag::Boxed` is `c10::List`, this
+ * iterator will wrap a `c10::List::const_iterator`.
+ *
+ * [Note: MSVC Iterator Debug]
+ * ===========================
+ * MSVC `vector<T>::iterator` implementation (used in the boxed variant)
+ * makes it so this union's destructor, copy-constructor (assignment), and
+ * move-constructor (assignment) are implicitly deleted.
+ *
+ * Therefore, we need to explicitly define them as needed. Follows a list
+ * of places where these are needed and their reason:
+ *
+ * - `Payload` destructor:
+ *   it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2.
+ *
+ * - `IListRefIterator` destructor:
+ *   same as above. However, we need to explicitly call the variant
+ *   destructor explicitly.
+ *
+ * - `IListRefIterator` copy-constructor:
+ *   it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different
+ *   than 0.
+ */
+template <typename T>
+class IListRefIterator {
+ private:
+#define DEFINE_FRIEND_CLASS(TAG, ...)                        \
+  friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
+  friend class detail::IListRefTagImplBase<                  \
+      IListRefTag::TAG,                                      \
+      T,                                                     \
+      typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
+#undef DEFINE_FRIEND_CLASS
+
+ public:
+  // C++17 friendly std::iterator implementation
+  using iterator_category = std::bidirectional_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T*;
+  using reference = T&;
+
+  using unboxed_iterator_type = typename detail::
+      IListRefTagImpl<IListRefTag::Unboxed, T>::list_type::const_iterator;
+  using boxed_iterator_type = typename detail::
+      IListRefTagImpl<IListRefTag::Boxed, T>::list_type::const_iterator;
+  using materialized_iterator_type =
+      typename detail::MaterializedIListRef<T>::const_iterator;
+
+  IListRefIterator() : tag_(IListRefTag::None) {}
+
+#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0
+  // See [Note: MSVC Iterator Debug]
+  IListRefIterator(const IListRefIterator& iterator)
+      : tag_(iterator.tag_) {
+    switch (tag_) {
+      case IListRefTag::Boxed:
+        payload_.boxed_iterator = iterator.payload_.boxed_iterator;
+        break;
+      case IListRefTag::Unboxed:
+        payload_.unboxed_iterator = iterator.payload_.unboxed_iterator;
+        break;
+      case IListRefTag::Materialized:
+        payload_.materialized_iterator = iterator.payload_.materialized_iterator;
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
+    }
+  }
+#endif
+
+#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2
+  // See [Note: MSVC Iterator Debug]
+  ~IListRefIterator() noexcept(false) {
+    switch (tag_) {
+      case IListRefTag::Boxed:
+        payload_.boxed_iterator.~boxed_iterator_type();
+        break;
+      case IListRefTag::Unboxed:
+        payload_.unboxed_iterator.~unboxed_iterator_type();
+        break;
+      case IListRefTag::Materialized:
+        payload_.materialized_iterator.~materialized_iterator_type();
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
+    }
+  }
+#endif
+
+  IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) {
+    payload_.boxed_iterator = boxed;
+  }
+
+  IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed_iterator = unboxed;
+  }
+
+  IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) {
+    payload_.materialized_iterator = materialized;
+  }
+
+  detail::IListRefConstRef<T> operator*() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); });
+  }
+
+  IListRefIterator& operator++() {
+    TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
+    return *this;
+  }
+
+  IListRefIterator operator++(int) {
+    auto old = *this;
+    TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
+    return old;
+  }
+
+  IListRefIterator& operator--() {
+    TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
+    return *this;
+  }
+
+  IListRefIterator operator--(int) {
+    auto old = *this;
+    TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
+    return old;
+  }
+
+  bool operator==(const IListRefIterator& rhs) const {
+    if (tag_ != rhs.tag_) {
+      return false;
+    }
+    TORCH_ILISTREF_UNWRAP(tag_, {
+      auto& rhs_it = ImplT::unwrap(rhs);
+      return this_ == rhs_it;
+    });
+  }
+
+  bool operator!=(const IListRefIterator& rhs) const {
+    return !(*this == rhs);
+  }
+
+ private:
+  union Payload {
+    boxed_iterator_type boxed_iterator;
+    unboxed_iterator_type unboxed_iterator;
+    materialized_iterator_type materialized_iterator;
+    void* _init_ptr;
+    Payload() : _init_ptr(nullptr) {}
+#if defined(_MSC_VER)
+    // See [Note: MSVC Iterator Debug]
+    ~Payload() {}
+#endif
+  };
+
+  Payload payload_;
+  IListRefTag tag_;
+};
+
+/*
+ * See [Note: IListRef]
+ */
+template <typename T>
+class IListRef {
+ private:
+#define DEFINE_FRIEND_CLASS(TAG, ...)                        \
+  friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
+  friend class detail::IListRefTagImplBase<                  \
+      IListRefTag::TAG,                                      \
+      T,                                                     \
+      typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
+#undef DEFINE_FRIEND_CLASS
+
+ public:
+  using unboxed_type =
+      typename detail::IListRefTagImpl<IListRefTag::Unboxed, T>::list_type;
+  using boxed_type =
+      typename detail::IListRefTagImpl<IListRefTag::Boxed, T>::list_type;
+  using materialized_type =
+      typename detail::MaterializedIListRef<T>;
+
+  using iterator = IListRefIterator<T>;
+  using const_iterator = IListRefIterator<T>;
+  using reverse_iterator = std::reverse_iterator<iterator>;
+  using value_type = typename iterator::value_type;
+
+  IListRef() : tag_(IListRefTag::None) {}
+
+  IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) {
+    payload_.boxed = &boxed;
+  }
+
+  IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed = unboxed;
+  }
+
+  IListRef(const std::initializer_list<T>& list) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed = at::ArrayRef<T>(list);
+  }
+
+  template <
+      typename... UnboxedConstructorArgs,
+      typename = std::enable_if_t<
+          std::is_constructible_v<unboxed_type, UnboxedConstructorArgs...>>>
+  IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed = unboxed_type(std::forward<UnboxedConstructorArgs>(args)...);
+  }
+
+  IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) {
+    payload_.materialized = &materialized;
+  }
+
+  size_t size() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); });
+  }
+
+  bool empty() const {
+    return size() == 0;
+  }
+
+  iterator begin() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); });
+  }
+
+  iterator end() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); });
+  }
+
+  detail::IListRefConstRef<T> front() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); });
+  }
+
+  /*
+   * Materializes the `IListRef` into a `std::vector`.
+   *
+   * This should be used when one wishes to either:
+   *
+   * - iterate over the list more than once: each `IListRefIterator`
+   *   member function call has to go through a switch, introducing
+   *   non-negligible overhead
+   *
+   * - randomly access an arbitrary element using `operator[]`:
+   *   same reason as above
+   */
+  detail::MaterializedIListRef<T> materialize() const {
+    if (isMaterialized()) {
+      return toMaterialized();
+    }
+
+    detail::MaterializedIListRef<T> materialized;
+    materialized.reserve(size());
+    for (const auto& t : *this) {
+      materialized.emplace_back(t);
+    }
+    return materialized;
+  }
+
+#define DEFINE_CHECK(TAG, ...)          \
+  bool is##TAG() const {                \
+    return tag_ == IListRefTag::TAG;    \
+  }
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK)
+#undef DEFINE_CHECK
+
+  bool isNone() const {
+    return tag_ == IListRefTag::None;
+  }
+
+#define DEFINE_CASTING(TAG, ...)                                          \
+  const typename detail::IListRefTagImpl<IListRefTag::TAG, T>::list_type& \
+      to##TAG() const {                                                   \
+    TORCH_INTERNAL_ASSERT(is##TAG());                                     \
+    return detail::IListRefTagImpl<IListRefTag::TAG, T>::unwrap(*this);   \
+  }
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING)
+#undef DEFINE_CASTING
+
+ private:
+  union Payload {
+    const boxed_type* boxed;
+    unboxed_type unboxed;
+    const materialized_type* materialized;
+    Payload() : boxed(nullptr) {}
+  };
+
+  Payload payload_;
+  IListRefTag tag_;
+};
+
+} // namespace c10
+
+#include <ATen/core/IListRef_inl.h>
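As a concrete illustration of the note above (iterate an `IListRef` at most once, or `materialize()` first), here is a hedged sketch of a kernel-style helper, not part of the diff, that accepts both boxed and unboxed tensor lists through a single signature:

#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>

// Callers can pass either a c10::List<at::Tensor> (boxed) or an
// at::ArrayRef<at::Tensor> (unboxed); no conversion between them happens.
int64_t total_numel(c10::IListRef<at::Tensor> tensors) {
  // Single pass, so iterating the IListRef directly is fine here.
  int64_t n = 0;
  for (const at::Tensor& t : tensors) {
    n += t.numel();
  }
  return n;
}

// When more than one pass (or random access) is needed, pay the one-off
// allocation and work on the materialized vector instead.
bool all_same_numel(c10::IListRef<at::Tensor> tensors) {
  auto vec = tensors.materialize();
  if (vec.empty()) {
    return true;
  }
  const int64_t first = vec[0].get().numel();
  for (const auto& ref : vec) {
    if (ref.get().numel() != first) {
      return false;
    }
  }
  return true;
}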
phivenv/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h
ADDED
@@ -0,0 +1,203 @@
+#pragma once
+
+#include <ATen/core/List.h>
+#include <ATen/core/Tensor.h>
+
+namespace at {
+class Tensor;
+class OptionalTensorRef;
+}
+
+
+namespace c10::detail {
+
+/*
+ * Specializations of `IListRefTagImplBase` that implement the default
+ * implementation for `IListRefTag::Unboxed`.
+ */
+template <typename T, typename ListElemT>
+class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> {
+ public:
+  using elem_type = ListElemT;
+  using list_type = ArrayRef<elem_type>;
+
+  /*
+   * These `unwrap` static methods unwraps the inner containers out
+   * of `IListRef<T>` (and `IListRefIterator<T>`). They are required when
+   * the macro `TORCH_ILISTREF_UNWRAP` is called.
+   */
+  static const list_type& unwrap(const IListRef<T>& ilist) {
+    return ilist.payload_.unboxed;
+  }
+
+  static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
+    return it.payload_.unboxed_iterator;
+  }
+
+  static const typename list_type::const_iterator& unwrap(
+      const IListRefIterator<T>& it) {
+    return it.payload_.unboxed_iterator;
+  }
+
+  /*
+   * We have these function (besides the `unwrap`s above) because the
+   * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*`
+   * weren't syntatically equal for the existing tags at the time
+   * (`Unboxed` and `Boxed`).
+   */
+  static IListRefConstRef<T> front(const list_type& lst) {
+    return lst.front();
+  }
+
+  static IListRefConstRef<T> iterator_get(
+      const typename list_type::const_iterator& it) {
+    return *it;
+  }
+};
+
+/*
+ * Specializations of `IListRefTagImplBase` that implement the default
+ * implementation for `IListRefTag::Boxed`.
+ */
+template <typename T, typename ListElemT>
+class IListRefTagImplBase<IListRefTag::Boxed, T, ListElemT> {
+ public:
+  using elem_type = ListElemT;
+  using list_type = List<elem_type>;
+
+  static const list_type& unwrap(const IListRef<T>& ilist) {
+    return *ilist.payload_.boxed;
+  }
+
+  static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
+    return it.payload_.boxed_iterator;
+  }
+
+  static const typename list_type::const_iterator& unwrap(
+      const IListRefIterator<T>& it) {
+    return it.payload_.boxed_iterator;
+  }
+
+  static IListRefConstRef<T> front(const list_type& lst) {
+    return lst[0];
+  }
+
+  static IListRefConstRef<T> iterator_get(
+      const typename list_type::const_iterator& it) {
+    return (*it).get().toTensor();
+  }
+};
+
+/*
+ * Specializations of `IListRefTagImplBase` that implement the default
+ * implementation for `IListRefTag::Materialized`.
+ */
+template <typename T>
+class IListRefTagImplBase<IListRefTag::Materialized, T, MaterializedIListRefElem<T>> {
+ public:
+  using elem_type = MaterializedIListRefElem<T>;
+  using list_type = MaterializedIListRef<T>;
+
+  static const list_type& unwrap(const IListRef<T>& ilist) {
+    return *ilist.payload_.materialized;
+  }
+
+  static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
+    return it.payload_.materialized_iterator;
+  }
+
+  static const typename list_type::const_iterator& unwrap(
+      const IListRefIterator<T>& it) {
+    return it.payload_.materialized_iterator;
+  }
+
+  static IListRefConstRef<T> front(const list_type& lst) {
+    return lst[0];
+  }
+
+  static IListRefConstRef<T> iterator_get(
+      const typename list_type::const_iterator& it) {
+    return *it;
+  }
+};
+
+/*
+ * [Note: ITensorListRef]
+ * Specializations necessary for `IListRef<at::Tensor>` type.
+ *
+ * Since the default implementations are usually done with supporting
+ * `Tensor` in mind, we only have to inherit from the base implementations.
+ */
+template <>
+class IListRefTagImpl<IListRefTag::Unboxed, at::Tensor>
+    : public IListRefTagImplBase<IListRefTag::Unboxed, at::Tensor> {};
+
+template <>
+class IListRefTagImpl<IListRefTag::Boxed, at::Tensor>
+    : public IListRefTagImplBase<IListRefTag::Boxed, at::Tensor> {};
+
+template <>
+class IListRefTagImpl<IListRefTag::Materialized, at::Tensor>
+    : public IListRefTagImplBase<
+          IListRefTag::Materialized,
+          at::Tensor,
+          MaterializedIListRefElem<at::Tensor>> {};
+
+/*
+ * [Note: IOptTensorListRef]
+ * Specializations necessary for `IListRef<at::OptionalTensorRef>` type.
+ *
+ * We can't get an `at::OptionalTensorRef` directly from an instance of
+ * `List<optional<Tensor>>` (the type that corresponds to the boxed world).
+ *
+ * So, the default implementation won't help us. Thus, we have to implement
+ * this method ourselves.
+ */
+template <>
+class IListRefTagImpl<IListRefTag::Unboxed, at::OptionalTensorRef>
+    : public IListRefTagImplBase<IListRefTag::Unboxed, at::OptionalTensorRef> {};
+
+template <>
+class IListRefTagImpl<IListRefTag::Boxed, at::OptionalTensorRef>
+    : public IListRefTagImplBase<IListRefTag::Boxed, at::OptionalTensorRef, std::optional<at::Tensor>> {
+
+ public:
+  /*
+   * Given an instance of the types corresponding to the `Boxed` tag, we override
+   * the default implementation, so that we can return a `at::OptionalTensorRef`.
+   */
+  static IListRefConstRef<at::OptionalTensorRef> iterator_get(
+      const typename list_type::const_iterator& it) {
+    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdangling-reference")
+    const auto& ivalue = (*it).get();
+    C10_DIAGNOSTIC_POP()
+    if (!ivalue.isNone()) {
+      const auto& tensor = ivalue.toTensor();
+      return (tensor.defined()) ? tensor : at::OptionalTensorRef{};
+    }
+    return {};
+  }
+};
+
+template <>
+class IListRefTagImpl<IListRefTag::Materialized, at::OptionalTensorRef>
+    : public IListRefTagImplBase<
+          IListRefTag::Materialized,
+          at::OptionalTensorRef,
+          MaterializedIListRefElem<at::OptionalTensorRef>> {};
+
+} // namespace c10::detail
+
+
+namespace at {
+
+// [Note: ITensorListRef]
+using ITensorListRef = c10::IListRef<at::Tensor>;
+using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>;
+using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>;
+// [Note: IOptTensorListRef]
+using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>;
+using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>;
+using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>;
+
+} // namespace at
phivenv/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// The legacy mechanism for dispatching operators in ATen is a Type
|
| 4 |
+
// object, which is essentially a giant virtual dispatch table
|
| 5 |
+
// for every operation we support dynamically dispatching over.
|
| 6 |
+
//
|
| 7 |
+
// This has been deprecated in favor of ATenDispatch, and in the future,
|
| 8 |
+
// c10 dispatcher.
|
| 9 |
+
// TODO: Clean up what remains here
|
| 10 |
+
|
| 11 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
|
| 15 |
+
// A RAII, thread local (!) guard that will disable dispatch to variable
|
| 16 |
+
// handler.
|
| 17 |
+
//
|
| 18 |
+
// NOTE [ Treating Variables as non-Variables in type dispatch ]
|
| 19 |
+
//
|
| 20 |
+
// What exactly does AutoDispatchBelowAutograd do? The short answer is, it causes
|
| 21 |
+
// dispatches on ATen functions to go to the non-variable implementation,
|
| 22 |
+
// bypassing autograd handling (and also profiling and tracing).
|
| 23 |
+
//
|
| 24 |
+
// To understand why this guard exists, it's helpful to understand the history
|
| 25 |
+
// behind how Variable was implemented. Previously, Variables were implemented
|
| 26 |
+
// as a wrapper on Tensors; so the act of processing a Variable involved
|
| 27 |
+
// unwrapping the underlying Tensor, and then calling the underlying base
|
| 28 |
+
// operation on /that/ operation
|
| 29 |
+
//
|
| 30 |
+
// However, after the Variable/Tensor merge, there is no concept of unwrapping
|
| 31 |
+
// a tensor anymore. If you just call the operation on the same variable
|
| 32 |
+
// again inside your VariableType handler, you'll dispatch back to
|
| 33 |
+
// VariableType, which is not what we want.
|
| 34 |
+
//
|
| 35 |
+
// The solution to the above problem is to add `at::AutoDispatchBelowAutograd`, which
|
| 36 |
+
// when enabled will cause `legacyTensorType()` and `getType()` to always return
|
| 37 |
+
// non-Variable type, even if the tensor being called on is a variable.
|
| 38 |
+
|
| 39 |
+
/* Note [AutoDispatchBelowAutograd]
|
| 40 |
+
* AutoDispatchBelowAutograd is **INTERNAL ONLY** that it should be used
|
| 41 |
+
* for kernel implementations and customized C++ kernels.
|
| 42 |
+
* If you are looking for a guard to run workload in inference mode, please use
|
| 43 |
+
* c10::InferenceMode RAII which is user facing API.
|
| 44 |
+
* In the past AutoDispatchBelowAutograd(or its old version AutoNonVariableTypeMode)
|
| 45 |
+
* was used in the user code for inference-only workload, this was under risk of
|
| 46 |
+
* producing wrong results silently in some edge cases. For example:
|
| 47 |
+
* ```
|
| 48 |
+
* torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
|
| 49 |
+
* torch::Tensor out = s * s;
|
| 50 |
+
* {
|
| 51 |
+
* at::AutoDispatchBelowAutograd guard;
|
| 52 |
+
* s.add_(1); // Skips version bump on `s`.
|
| 53 |
+
* }
|
| 54 |
+
* // WRONG GRADIENT! s.grad() are now computed using `s` value after the
|
| 55 |
+
* // inplace update.
|
| 56 |
+
* out.backward(torch::ones_like(out));
|
| 57 |
+
* ```
|
| 58 |
+
* Users should use `c10::InferenceMode` here so that it'll properly throw an
|
| 59 |
+
* error saying "one of the variables needed for gradient computation has be modified."
|
| 60 |
+
*/
|
| 61 |
+
struct TORCH_API AutoDispatchBelowAutograd {
|
| 62 |
+
AutoDispatchBelowAutograd() :
|
| 63 |
+
autograd_guard_(c10::autograd_dispatch_keyset) {
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// disable all autograd dispatch keys
|
| 67 |
+
c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
|
| 68 |
+
};
|
| 69 |
+
|
| 70 |
+
// TODO: AutoNonVariableTypeMode should be removed in release 1.10.
|
| 71 |
+
struct TORCH_API AutoNonVariableTypeMode {
|
| 72 |
+
AutoNonVariableTypeMode(bool enabled = true) :
|
| 73 |
+
autograd_guard_(c10::autograd_dispatch_keyset) {
|
| 74 |
+
TORCH_WARN_ONCE("AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. "
|
| 75 |
+
"For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, "
|
| 76 |
+
"If you are looking for a user facing API to enable running your inference-only "
|
| 77 |
+
"workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code "
|
| 78 |
+
"is under risk of producing silent wrong result in some edge cases. "
|
| 79 |
+
"See Note [AutoDispatchBelowAutograd] for more details.");
|
| 80 |
+
TORCH_INTERNAL_ASSERT(enabled);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// disable all autograd dispatch keys
|
| 84 |
+
c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
struct TORCH_API AutoDispatchSkipFunctionalize {
|
| 88 |
+
AutoDispatchSkipFunctionalize() :
|
| 89 |
+
dispatch_key_guard_(c10::DispatchKeySet(c10::DispatchKey::Functionalize)) {
|
| 90 |
+
}
|
| 91 |
+
c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
/* Note [AutoDispatchBelowADInplaceOrView]
|
| 95 |
+
* AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode
|
| 96 |
+
* before we split inplace & view ops out of VariableType kernel.
|
| 97 |
+
* Note this guard is used in VariableType kernels for functional ops
|
| 98 |
+
* as well as ADInplaceOrView kernels for inplace/view ops to enforce the
|
| 99 |
+
* Invariant:
|
| 100 |
+
* Once you are in VariableType/ADInplaceOrView kernel for an op,
|
| 101 |
+
* you never go back to a kernel on same dispatch key until
|
| 102 |
+
* you finish the current op.
|
| 103 |
+
*/
|
| 104 |
+
struct TORCH_API AutoDispatchBelowADInplaceOrView {
|
| 105 |
+
AutoDispatchBelowADInplaceOrView() :
|
| 106 |
+
dispatch_key_guard_(c10::autograd_dispatch_keyset_with_ADInplaceOrView) {
|
| 107 |
+
}
|
| 108 |
+
// disable Autograd & ADInplaceOrView dispatch keys
|
| 109 |
+
c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
|
| 110 |
+
};
|
| 111 |
+
} // namespace at
|
phivenv/Lib/site-packages/torch/include/ATen/core/List.h
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/ivalue_to.h>
|
| 4 |
+
#include <ATen/core/jit_type_base.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
#include <c10/macros/Export.h>
|
| 7 |
+
#include <c10/util/TypeTraits.h>
|
| 8 |
+
#include <c10/util/TypeList.h>
|
| 9 |
+
#include <c10/util/intrusive_ptr.h>
|
| 10 |
+
#include <c10/util/ArrayRef.h>
|
| 11 |
+
#include <optional>
|
| 12 |
+
#include <vector>
|
| 13 |
+
|
| 14 |
+
namespace at {
|
| 15 |
+
class Tensor;
|
| 16 |
+
}
|
| 17 |
+
namespace c10 {
|
| 18 |
+
struct IValue;
|
| 19 |
+
template<class T> class List;
|
| 20 |
+
struct Type;
|
| 21 |
+
|
| 22 |
+
namespace detail {
|
| 23 |
+
|
| 24 |
+
struct ListImpl final : public c10::intrusive_ptr_target {
|
| 25 |
+
using list_type = std::vector<IValue>;
|
| 26 |
+
|
| 27 |
+
explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_);
|
| 28 |
+
|
| 29 |
+
list_type list;
|
| 30 |
+
|
| 31 |
+
TypePtr elementType;
|
| 32 |
+
|
| 33 |
+
intrusive_ptr<ListImpl> copy() const {
|
| 34 |
+
return make_intrusive<ListImpl>(list, elementType);
|
| 35 |
+
}
|
| 36 |
+
friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
|
| 37 |
+
};
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
namespace impl {
|
| 41 |
+
|
| 42 |
+
template<class T, class Iterator> class ListIterator;
|
| 43 |
+
|
| 44 |
+
template<class T, class Iterator> class ListElementReference;
|
| 45 |
+
|
| 46 |
+
template<class T, class Iterator>
|
| 47 |
+
void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept;
|
| 48 |
+
|
| 49 |
+
template<class T, class Iterator>
|
| 50 |
+
bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs);
|
| 51 |
+
|
| 52 |
+
template<class T, class Iterator>
|
| 53 |
+
bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
|
| 54 |
+
|
| 55 |
+
template<class T>
|
| 56 |
+
struct ListElementConstReferenceTraits {
|
| 57 |
+
// In the general case, we use IValue::to().
|
| 58 |
+
using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
// There is no to() overload for std::optional<std::string>.
|
| 62 |
+
template<>
|
| 63 |
+
struct ListElementConstReferenceTraits<std::optional<std::string>> {
|
| 64 |
+
using const_reference = std::optional<std::reference_wrapper<const std::string>>;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
template<class T, class Iterator>
|
| 68 |
+
class ListElementReference final {
|
| 69 |
+
public:
|
| 70 |
+
operator std::conditional_t<
|
| 71 |
+
std::is_reference_v<typename c10::detail::
|
| 72 |
+
ivalue_to_const_ref_overload_return<T>::type>,
|
| 73 |
+
const T&,
|
| 74 |
+
T>() const;
|
| 75 |
+
|
| 76 |
+
ListElementReference& operator=(T&& new_value) &&;
|
| 77 |
+
|
| 78 |
+
ListElementReference& operator=(const T& new_value) &&;
|
| 79 |
+
|
| 80 |
+
// assigning another ref to this assigns the underlying value
|
| 81 |
+
ListElementReference& operator=(ListElementReference&& rhs) && noexcept;
|
| 82 |
+
|
| 83 |
+
const IValue& get() const& {
|
| 84 |
+
return *iterator_;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs) noexcept;
|
| 88 |
+
|
| 89 |
+
ListElementReference(const ListElementReference&) = delete;
|
| 90 |
+
ListElementReference& operator=(const ListElementReference&) = delete;
|
| 91 |
+
~ListElementReference() = default;
|
| 92 |
+
|
| 93 |
+
private:
|
| 94 |
+
ListElementReference(Iterator iter)
|
| 95 |
+
: iterator_(iter) {}
|
| 96 |
+
|
| 97 |
+
// allow moving, but only our friends (i.e. the List class) can move us
|
| 98 |
+
ListElementReference(ListElementReference&&) noexcept = default;
|
| 99 |
+
ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
|
| 100 |
+
iterator_ = std::move(rhs.iterator_);
|
| 101 |
+
return *this;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
friend class List<T>;
|
| 105 |
+
friend class ListIterator<T, Iterator>;
|
| 106 |
+
|
| 107 |
+
Iterator iterator_;
|
| 108 |
+
};
|
| 109 |
+
|
| 110 |
+
// this wraps vector::iterator to make sure user code can't rely
|
| 111 |
+
// on it being the type of the underlying vector.
|
| 112 |
+
template <class T, class Iterator>
|
| 113 |
+
class ListIterator final {
|
| 114 |
+
public:
|
| 115 |
+
// C++17 friendly std::iterator implementation
|
| 116 |
+
using iterator_category = std::random_access_iterator_tag;
|
| 117 |
+
using value_type = T;
|
| 118 |
+
using difference_type = std::ptrdiff_t;
|
| 119 |
+
using pointer = T*;
|
| 120 |
+
using reference = ListElementReference<T, Iterator>;
|
| 121 |
+
|
| 122 |
+
explicit ListIterator() = default;
|
| 123 |
+
~ListIterator() = default;
|
| 124 |
+
|
| 125 |
+
ListIterator(const ListIterator&) = default;
|
| 126 |
+
ListIterator(ListIterator&&) noexcept = default;
|
| 127 |
+
ListIterator& operator=(const ListIterator&) = default;
|
| 128 |
+
ListIterator& operator=(ListIterator&&) noexcept = default;
|
| 129 |
+
|
| 130 |
+
ListIterator& operator++() {
|
| 131 |
+
++iterator_;
|
| 132 |
+
return *this;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
ListIterator operator++(int) {
|
| 136 |
+
ListIterator copy(*this);
|
| 137 |
+
++*this;
|
| 138 |
+
return copy;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
ListIterator& operator--() {
|
| 142 |
+
--iterator_;
|
| 143 |
+
return *this;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
ListIterator operator--(int) {
|
| 147 |
+
ListIterator copy(*this);
|
| 148 |
+
--*this;
|
| 149 |
+
return copy;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
ListIterator& operator+=(typename List<T>::size_type offset) {
|
| 153 |
+
iterator_ += offset;
|
| 154 |
+
return *this;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
ListIterator& operator-=(typename List<T>::size_type offset) {
|
| 158 |
+
iterator_ -= offset;
|
| 159 |
+
return *this;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
ListIterator operator+(typename List<T>::size_type offset) const {
|
| 163 |
+
return ListIterator{iterator_ + offset};
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
ListIterator operator-(typename List<T>::size_type offset) const {
|
| 167 |
+
return ListIterator{iterator_ - offset};
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
|
| 171 |
+
return lhs.iterator_ - rhs.iterator_;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
ListElementReference<T, Iterator> operator*() const {
|
| 175 |
+
return {iterator_};
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
|
| 179 |
+
return {iterator_ + offset};
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
private:
|
| 183 |
+
explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
|
| 184 |
+
|
| 185 |
+
Iterator iterator_;
|
| 186 |
+
|
| 187 |
+
friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
|
| 188 |
+
return lhs.iterator_ == rhs.iterator_;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
|
| 192 |
+
return !(lhs == rhs);
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
|
| 196 |
+
return lhs.iterator_ < rhs.iterator_;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
|
| 200 |
+
return lhs.iterator_ <= rhs.iterator_;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
|
| 204 |
+
return lhs.iterator_ > rhs.iterator_;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
|
| 208 |
+
return lhs.iterator_ >= rhs.iterator_;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
|
| 212 |
+
friend class List<T>;
|
| 213 |
+
};
|
| 214 |
+
|
| 215 |
+
template<class T> List<T> toTypedList(List<IValue> list);
|
| 216 |
+
template<class T> List<IValue> toList(List<T>&& list);
|
| 217 |
+
template<class T> List<IValue> toList(const List<T>& list);
|
| 218 |
+
const IValue* ptr_to_first_element(const List<IValue>& list);
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
/**
|
| 222 |
+
* An object of this class stores a list of values of type T.
|
| 223 |
+
*
|
| 224 |
+
* This is a pointer type. After a copy, both Lists
|
| 225 |
+
* will share the same storage:
|
| 226 |
+
*
|
| 227 |
+
* > List<int> a;
|
| 228 |
+
* > List<int> b = a;
|
| 229 |
+
* > b.push_back("three");
|
| 230 |
+
* > ASSERT("three" == a.get(0));
|
| 231 |
+
*
|
| 232 |
+
* We use this class in the PyTorch kernel API instead of
|
| 233 |
+
* std::vector<T>, because that allows us to do optimizations
|
| 234 |
+
* and switch out the underlying list implementation without
|
| 235 |
+
* breaking backwards compatibility for the kernel API.
|
| 236 |
+
*/
|
| 237 |
+
template<class T>
|
| 238 |
+
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
|
| 239 |
+
class List final {
|
| 240 |
+
private:
|
| 241 |
+
// This is an intrusive_ptr because List is a pointer type.
|
| 242 |
+
// Invariant: This will never be a nullptr, there will always be a valid
|
| 243 |
+
// ListImpl.
|
| 244 |
+
c10::intrusive_ptr<c10::detail::ListImpl> impl_;
|
| 245 |
+
|
| 246 |
+
using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>;
|
| 247 |
+
using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference;
|
| 248 |
+
|
| 249 |
+
public:
|
| 250 |
+
using value_type = T;
|
| 251 |
+
using size_type = typename c10::detail::ListImpl::list_type::size_type;
|
| 252 |
+
using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
|
| 253 |
+
using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
|
| 254 |
+
using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>;
|
| 255 |
+
|
| 256 |
+
/**
|
| 257 |
+
* Constructs an empty list.
|
| 258 |
+
*/
|
| 259 |
+
explicit List();
|
| 260 |
+
|
| 261 |
+
/**
|
| 262 |
+
* Constructs a list with some initial values.
|
| 263 |
+
* Example:
|
| 264 |
+
* List<int> a({2, 3, 4});
|
| 265 |
+
*/
|
| 266 |
+
List(std::initializer_list<T> initial_values);
|
| 267 |
+
explicit List(ArrayRef<T> initial_values);
|
| 268 |
+
|
| 269 |
+
/**
|
| 270 |
+
* Create a generic list with runtime type information.
|
| 271 |
+
* This only works for c10::impl::GenericList and is not part of the public API
|
| 272 |
+
* but only supposed to be used internally by PyTorch.
|
| 273 |
+
*/
|
| 274 |
+
explicit List(TypePtr elementType);
|
| 275 |
+
|
| 276 |
+
List(const List&) = default;
|
| 277 |
+
List& operator=(const List&) = default;
|
| 278 |
+
~List() = default;
|
| 279 |
+
|
| 280 |
+
/**
|
| 281 |
+
* Create a new List pointing to a deep copy of the same data.
|
| 282 |
+
* The List returned is a new list with separate storage.
|
| 283 |
+
* Changes in it are not reflected in the original list or vice versa.
|
| 284 |
+
*/
|
| 285 |
+
List copy() const;
|
| 286 |
+
|
| 287 |
+
/**
|
| 288 |
+
* Returns the element at specified location pos, with bounds checking.
|
| 289 |
+
* If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
|
| 290 |
+
*/
|
| 291 |
+
internal_const_reference_type get(size_type pos) const;
|
| 292 |
+
|
| 293 |
+
/**
|
| 294 |
+
* Moves out the element at the specified location pos and returns it, with bounds checking.
|
| 295 |
+
* If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
|
| 296 |
+
* The list contains an invalid element at position pos afterwards. Any operations
|
| 297 |
+
* on it before re-setting it are invalid.
|
| 298 |
+
*/
|
| 299 |
+
value_type extract(size_type pos) const;
|
| 300 |
+
|
| 301 |
+
/**
|
| 302 |
+
* Returns a reference to the element at specified location pos, with bounds checking.
|
| 303 |
+
* If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
|
| 304 |
+
*
|
| 305 |
+
* You cannot store the reference, but you can read it and assign new values to it:
|
| 306 |
+
*
|
| 307 |
+
* List<int64_t> list = ...;
|
| 308 |
+
* list[2] = 5;
|
| 309 |
+
* int64_t v = list[1];
|
| 310 |
+
*/
|
| 311 |
+
internal_const_reference_type operator[](size_type pos) const;
|
| 312 |
+
|
| 313 |
+
internal_reference_type operator[](size_type pos);
|
| 314 |
+
|
| 315 |
+
/**
|
| 316 |
+
* Assigns a new value to the element at location pos.
|
| 317 |
+
*/
|
| 318 |
+
void set(size_type pos, const value_type& value) const;
|
| 319 |
+
|
| 320 |
+
/**
|
| 321 |
+
* Assigns a new value to the element at location pos.
|
| 322 |
+
*/
|
| 323 |
+
void set(size_type pos, value_type&& value) const;
|
| 324 |
+
|
| 325 |
+
/**
|
| 326 |
+
* Returns an iterator to the first element of the container.
|
| 327 |
+
* If the container is empty, the returned iterator will be equal to end().
|
| 328 |
+
*/
|
| 329 |
+
iterator begin() const;
|
| 330 |
+
|
| 331 |
+
/**
|
| 332 |
+
* Returns an iterator to the element following the last element of the container.
|
| 333 |
+
* This element acts as a placeholder; attempting to access it results in undefined behavior.
|
| 334 |
+
*/
|
| 335 |
+
iterator end() const;
|
| 336 |
+
|
| 337 |
+
/**
|
| 338 |
+
* Checks if the container has no elements.
|
| 339 |
+
*/
|
| 340 |
+
bool empty() const;
|
| 341 |
+
|
| 342 |
+
/**
|
| 343 |
+
* Returns the number of elements in the container
|
| 344 |
+
*/
|
| 345 |
+
size_type size() const;
|
| 346 |
+
|
| 347 |
+
/**
|
| 348 |
+
* Increase the capacity of the vector to a value that's greater or equal to new_cap.
|
| 349 |
+
*/
|
| 350 |
+
void reserve(size_type new_cap) const;
|
| 351 |
+
|
| 352 |
+
/**
|
| 353 |
+
* Erases all elements from the container. After this call, size() returns zero.
|
| 354 |
+
* Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
|
| 355 |
+
*/
|
| 356 |
+
void clear() const;
|
| 357 |
+
|
| 358 |
+
/**
|
| 359 |
+
* Inserts value before pos.
|
| 360 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 361 |
+
*/
|
| 362 |
+
iterator insert(iterator pos, const T& value) const;
|
| 363 |
+
|
| 364 |
+
/**
|
| 365 |
+
* Inserts value before pos.
|
| 366 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 367 |
+
*/
|
| 368 |
+
iterator insert(iterator pos, T&& value) const;
|
| 369 |
+
|
| 370 |
+
/**
|
| 371 |
+
* Inserts a new element into the container directly before pos.
|
| 372 |
+
* The new element is constructed with the given arguments.
|
| 373 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 374 |
+
*/
|
| 375 |
+
template<class... Args>
|
| 376 |
+
iterator emplace(iterator pos, Args&&... value) const;
|
| 377 |
+
|
| 378 |
+
/**
|
| 379 |
+
* Appends the given element value to the end of the container.
|
| 380 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 381 |
+
*/
|
| 382 |
+
void push_back(const T& value) const;
|
| 383 |
+
|
| 384 |
+
/**
|
| 385 |
+
* Appends the given element value to the end of the container.
|
| 386 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 387 |
+
*/
|
| 388 |
+
void push_back(T&& value) const;
|
| 389 |
+
|
| 390 |
+
/**
|
| 391 |
+
* Appends the given list to the end of the container. Uses at most one memory allocation.
|
| 392 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 393 |
+
*/
|
| 394 |
+
void append(List<T> lst) const;
|
| 395 |
+
|
| 396 |
+
/**
|
| 397 |
+
* Appends the given element value to the end of the container.
|
| 398 |
+
* The new element is constructed with the given arguments.
|
| 399 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 400 |
+
*/
|
| 401 |
+
template<class... Args>
|
| 402 |
+
void emplace_back(Args&&... args) const;
|
| 403 |
+
|
| 404 |
+
/**
|
| 405 |
+
* Removes the element at pos.
|
| 406 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 407 |
+
*/
|
| 408 |
+
iterator erase(iterator pos) const;
|
| 409 |
+
|
| 410 |
+
/**
|
| 411 |
+
* Removes the elements in the range [first, last).
|
| 412 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 413 |
+
*/
|
| 414 |
+
iterator erase(iterator first, iterator last) const;
|
| 415 |
+
|
| 416 |
+
/**
|
| 417 |
+
* Removes the last element of the container.
|
| 418 |
+
* Calling pop_back on an empty container is undefined.
|
| 419 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 420 |
+
*/
|
| 421 |
+
void pop_back() const;
|
| 422 |
+
|
| 423 |
+
/**
|
| 424 |
+
* Resizes the container to contain count elements.
|
| 425 |
+
* If the current size is less than count, additional default-inserted elements are appended.
|
| 426 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 427 |
+
*/
|
| 428 |
+
void resize(size_type count) const;
|
| 429 |
+
|
| 430 |
+
/**
|
| 431 |
+
* Resizes the container to contain count elements.
|
| 432 |
+
* If the current size is less than count, additional copies of value are appended.
|
| 433 |
+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
|
| 434 |
+
*/
|
| 435 |
+
void resize(size_type count, const T& value) const;
|
| 436 |
+
|
| 437 |
+
/**
|
| 438 |
+
* Value equality comparison. This function implements Python-like semantics for
|
| 439 |
+
* equality: two lists with the same identity (e.g. same pointer) trivially
|
| 440 |
+
* compare equal, otherwise each element is compared for equality.
|
| 441 |
+
*/
|
| 442 |
+
template <class T_>
|
| 443 |
+
friend bool operator==(const List<T_>& lhs, const List<T_>& rhs);
|
| 444 |
+
|
| 445 |
+
template <class T_>
|
| 446 |
+
friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs);
|
| 447 |
+
|
| 448 |
+
/**
|
| 449 |
+
* Identity comparison. Returns true if and only if `rhs` represents the same
|
| 450 |
+
* List object as `this`.
|
| 451 |
+
*/
|
| 452 |
+
bool is(const List<T>& rhs) const;
|
| 453 |
+
|
| 454 |
+
std::vector<T> vec() const;
|
| 455 |
+
|
| 456 |
+
/**
|
| 457 |
+
* Returns the number of Lists currently pointing to this same list.
|
| 458 |
+
* If this is the only instance pointing to this list, returns 1.
|
| 459 |
+
*/
|
| 460 |
+
// TODO Test use_count
|
| 461 |
+
size_t use_count() const;
|
| 462 |
+
|
| 463 |
+
TypePtr elementType() const;
|
| 464 |
+
|
| 465 |
+
// See [unsafe set type] for why this exists.
|
| 466 |
+
void unsafeSetElementType(TypePtr t);
|
| 467 |
+
|
| 468 |
+
private:
|
| 469 |
+
explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements);
|
| 470 |
+
explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements);
|
| 471 |
+
friend struct IValue;
|
| 472 |
+
template<class T_> friend List<T_> impl::toTypedList(List<IValue>);
|
| 473 |
+
template<class T_> friend List<IValue> impl::toList(List<T_>&&);
|
| 474 |
+
template<class T_> friend List<IValue> impl::toList(const List<T_>&);
|
| 475 |
+
friend const IValue* impl::ptr_to_first_element(const List<IValue>& list);
|
| 476 |
+
};
|
| 477 |
+
|
| 478 |
+
namespace impl {
|
| 479 |
+
// GenericList is how IValue stores lists. It is, however, not part of the
|
| 480 |
+
// public API. Kernels should use Lists with concrete types instead
|
| 481 |
+
// (maybe except for some internal prim ops).
|
| 482 |
+
using GenericList = List<IValue>;
|
| 483 |
+
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
namespace torch {
|
| 488 |
+
template<class T> using List = c10::List<T>;
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
#include <ATen/core/List_inl.h> // IWYU pragma: keep
|
phivenv/Lib/site-packages/torch/include/ATen/core/List_inl.h
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/jit_type_base.h>
|
| 4 |
+
#include <ATen/core/ivalue.h>
|
| 5 |
+
|
| 6 |
+
namespace c10 {
|
| 7 |
+
|
| 8 |
+
template<class T> decltype(auto) getTypePtr();
|
| 9 |
+
std::string toString(const Type& type);
|
| 10 |
+
|
| 11 |
+
template<class T>
|
| 12 |
+
List<T>::List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements)
|
| 13 |
+
: impl_(std::move(elements)) {}
|
| 14 |
+
|
| 15 |
+
template<class T>
|
| 16 |
+
List<T>::List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements)
|
| 17 |
+
: impl_(elements) {}
|
| 18 |
+
|
| 19 |
+
template<class T>
|
| 20 |
+
List<T>::List()
|
| 21 |
+
: List(make_intrusive<c10::detail::ListImpl>(
|
| 22 |
+
typename c10::detail::ListImpl::list_type(),
|
| 23 |
+
getTypePtr<T>())) {
|
| 24 |
+
static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType) instead.");
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
template<class T>
|
| 28 |
+
List<T>::List(ArrayRef<T> values)
|
| 29 |
+
: List(make_intrusive<c10::detail::ListImpl>(
|
| 30 |
+
typename c10::detail::ListImpl::list_type(),
|
| 31 |
+
getTypePtr<T>())) {
|
| 32 |
+
static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
|
| 33 |
+
impl_->list.reserve(values.size());
|
| 34 |
+
for (const T& element : values) {
|
| 35 |
+
impl_->list.push_back(element);
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
template<class T>
|
| 40 |
+
List<T>::List(std::initializer_list<T> initial_values)
|
| 41 |
+
: List(ArrayRef<T>(initial_values)) {
|
| 42 |
+
static_assert(!std::is_same_v<T, IValue>, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template<class T>
|
| 46 |
+
List<T>::List(TypePtr elementType)
|
| 47 |
+
: List(make_intrusive<c10::detail::ListImpl>(
|
| 48 |
+
typename c10::detail::ListImpl::list_type(),
|
| 49 |
+
std::move(elementType))) {
|
| 50 |
+
static_assert(std::is_same_v<T, IValue> || std::is_same_v<T, c10::intrusive_ptr<ivalue::Future>>,
|
| 51 |
+
"This constructor is only valid for c10::impl::GenericList or List<Future>.");
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
namespace impl {
|
| 55 |
+
template<class T>
|
| 56 |
+
List<T> toTypedList(impl::GenericList list) {
|
| 57 |
+
// If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
|
| 58 |
+
// because upcasting would allow people to add types into the new list that would break the old list.
|
| 59 |
+
// However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
|
| 60 |
+
// allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
|
| 61 |
+
// without having to copy it. This is also used to provide backwards compatibility with some old models
|
| 62 |
+
// that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
|
| 63 |
+
// as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
|
| 64 |
+
// have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
|
| 65 |
+
TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
|
| 66 |
+
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
|
| 67 |
+
, "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr<T>()), ">. Types mismatch.");
|
| 68 |
+
return List<T>(std::move(list.impl_));
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
template<class T>
|
| 72 |
+
impl::GenericList toList(List<T>&& list) {
|
| 73 |
+
return GenericList(std::move(list.impl_));
|
| 74 |
+
}
|
| 75 |
+
template<class T>
|
| 76 |
+
impl::GenericList toList(const List<T>& list) {
|
| 77 |
+
return GenericList(list.impl_);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template<class T>
|
| 82 |
+
List<T> List<T>::copy() const {
|
| 83 |
+
return List<T>(impl_->copy());
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
namespace detail {
|
| 87 |
+
template<class T>
|
| 88 |
+
T list_element_to(T element) {
|
| 89 |
+
return element;
|
| 90 |
+
}
|
| 91 |
+
template<class T>
|
| 92 |
+
T list_element_to(const IValue& element) {
|
| 93 |
+
return element.template to<T>();
|
| 94 |
+
}
|
| 95 |
+
template<class T>
|
| 96 |
+
T list_element_to(IValue&& element) {
|
| 97 |
+
return std::move(element).template to<T>();
|
| 98 |
+
}
|
| 99 |
+
template<class T>
|
| 100 |
+
struct ListElementFrom {
|
| 101 |
+
static IValue from(const T& element) {
|
| 102 |
+
return element;
|
| 103 |
+
}
|
| 104 |
+
static IValue from(T&& element) {
|
| 105 |
+
return std::move(element);
|
| 106 |
+
}
|
| 107 |
+
};
|
| 108 |
+
template<>
|
| 109 |
+
struct ListElementFrom<IValue> {
|
| 110 |
+
static const IValue& from(const IValue& element) {
|
| 111 |
+
return element;
|
| 112 |
+
}
|
| 113 |
+
static IValue&& from(IValue&& element) {
|
| 114 |
+
return std::move(element);
|
| 115 |
+
}
|
| 116 |
+
};
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
namespace impl {
|
| 120 |
+
|
| 121 |
+
template <class T, class Iterator>
|
| 122 |
+
ListElementReference<T, Iterator>::operator std::conditional_t<
|
| 123 |
+
std::is_reference_v<typename c10::detail::ivalue_to_const_ref_overload_return<
|
| 124 |
+
T>::type>,
|
| 125 |
+
const T&,
|
| 126 |
+
T>() const {
|
| 127 |
+
return iterator_->template to<T>();
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
template<class T, class Iterator>
|
| 131 |
+
ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(T&& new_value) && {
|
| 132 |
+
*iterator_ = c10::detail::ListElementFrom<T>::from(std::move(new_value));
|
| 133 |
+
return *this;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
template<class T, class Iterator>
|
| 137 |
+
ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(const T& new_value) && {
|
| 138 |
+
*iterator_ = c10::detail::ListElementFrom<T>::from(new_value);
|
| 139 |
+
return *this;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
template<class T, class Iterator>
|
| 143 |
+
ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(ListElementReference<T, Iterator>&& rhs) && noexcept {
|
| 144 |
+
*iterator_ = *rhs.iterator_;
|
| 145 |
+
return *this;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
template<class T, class Iterator>
|
| 149 |
+
void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept {
|
| 150 |
+
std::swap(*lhs.iterator_, *rhs.iterator_);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
template<class T, class Iterator>
|
| 154 |
+
bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs) {
|
| 155 |
+
const T& lhs_tmp = lhs;
|
| 156 |
+
return lhs_tmp == rhs;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
template<class T, class Iterator>
|
| 160 |
+
inline bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs) {
|
| 161 |
+
return rhs == lhs;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
template<class T>
|
| 165 |
+
inline typename ListElementConstReferenceTraits<T>::const_reference
|
| 166 |
+
list_element_to_const_ref(const IValue& element) {
|
| 167 |
+
return element.template to<T>();
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
template<>
|
| 171 |
+
inline typename ListElementConstReferenceTraits<std::optional<std::string>>::const_reference
|
| 172 |
+
list_element_to_const_ref<std::optional<std::string>>(const IValue& element) {
|
| 173 |
+
return element.toOptionalStringRef();
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
} // namespace impl
|
| 177 |
+
|
| 178 |
+
template<class T>
|
| 179 |
+
void List<T>::set(size_type pos, const value_type& value) const {
|
| 180 |
+
impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(value);
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
template<class T>
|
| 184 |
+
void List<T>::set(size_type pos, value_type&& value) const {
|
| 185 |
+
impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(std::move(value));
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
template<class T>
|
| 189 |
+
typename List<T>::internal_const_reference_type List<T>::get(size_type pos) const {
|
| 190 |
+
return operator[](pos);
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
template<class T>
|
| 194 |
+
typename List<T>::internal_const_reference_type List<T>::operator[](size_type pos) const {
|
| 195 |
+
return c10::impl::list_element_to_const_ref<T>(impl_->list.at(pos));
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
template<class T>
|
| 199 |
+
typename List<T>::internal_reference_type List<T>::operator[](size_type pos) {
|
| 200 |
+
static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
|
| 201 |
+
return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
template<class T>
|
| 205 |
+
typename List<T>::value_type List<T>::extract(size_type pos) const {
|
| 206 |
+
auto& elem = impl_->list.at(pos);
|
| 207 |
+
auto result = c10::detail::list_element_to<T>(std::move(elem));
|
| 208 |
+
// Reset the list element to a T() instead of None to keep it correctly typed
|
| 209 |
+
elem = c10::detail::ListElementFrom<T>::from(T{});
|
| 210 |
+
return result;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
template<class T>
|
| 214 |
+
typename List<T>::iterator List<T>::begin() const {
|
| 215 |
+
return iterator(impl_->list.begin());
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
template<class T>
|
| 219 |
+
typename List<T>::iterator List<T>::end() const {
|
| 220 |
+
return iterator(impl_->list.end());
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
template<class T>
|
| 224 |
+
bool List<T>::empty() const {
|
| 225 |
+
return impl_->list.empty();
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
template<class T>
|
| 229 |
+
typename List<T>::size_type List<T>::size() const {
|
| 230 |
+
return impl_->list.size();
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
template<class T>
|
| 234 |
+
void List<T>::reserve(size_type new_cap) const {
|
| 235 |
+
impl_->list.reserve(new_cap);
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
template<class T>
|
| 239 |
+
void List<T>::clear() const {
|
| 240 |
+
impl_->list.clear();
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
template<class T>
|
| 244 |
+
typename List<T>::iterator List<T>::insert(iterator pos, const T& value) const {
|
| 245 |
+
return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(value)) };
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
template<class T>
|
| 249 |
+
typename List<T>::iterator List<T>::insert(iterator pos, T&& value) const {
|
| 250 |
+
return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(std::move(value))) };
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
template<class T>
|
| 254 |
+
template<class... Args>
|
| 255 |
+
typename List<T>::iterator List<T>::emplace(iterator pos, Args&&... value) const {
|
| 256 |
+
// TODO Use list_element_from?
|
| 257 |
+
return iterator { impl_->list.emplace(pos.iterator_, std::forward<Args>(value)...) };
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
template<class T>
|
| 261 |
+
void List<T>::push_back(const T& value) const {
|
| 262 |
+
impl_->list.push_back(c10::detail::ListElementFrom<T>::from(value));
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
template<class T>
|
| 266 |
+
void List<T>::push_back(T&& value) const {
|
| 267 |
+
impl_->list.push_back(c10::detail::ListElementFrom<T>::from(std::move(value)));
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
template<class T>
|
| 271 |
+
void List<T>::append(List<T> b) const {
|
| 272 |
+
if (b.use_count() == 1) {
|
| 273 |
+
impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
|
| 274 |
+
} else {
|
| 275 |
+
impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
template<class T>
|
| 280 |
+
template<class... Args>
|
| 281 |
+
void List<T>::emplace_back(Args&&... args) const {
|
| 282 |
+
// TODO Use list_element_from?
|
| 283 |
+
impl_->list.push_back(T(std::forward<Args>(args)...));
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
template<class T>
|
| 287 |
+
typename List<T>::iterator List<T>::erase(iterator pos) const {
|
| 288 |
+
return iterator { impl_->list.erase(pos.iterator_) };
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
template<class T>
|
| 292 |
+
typename List<T>::iterator List<T>::erase(iterator first, iterator last) const {
|
| 293 |
+
return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
template<class T>
|
| 297 |
+
void List<T>::pop_back() const {
|
| 298 |
+
impl_->list.pop_back();
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
template<class T>
|
| 302 |
+
void List<T>::resize(size_type count) const {
|
| 303 |
+
impl_->list.resize(count, T{});
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
template<class T>
|
| 307 |
+
void List<T>::resize(size_type count, const T& value) const {
|
| 308 |
+
impl_->list.resize(count, value);
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
template<class T>
|
| 312 |
+
bool operator==(const List<T>& lhs, const List<T>& rhs) {
|
| 313 |
+
// Lists with the same identity trivially compare equal.
|
| 314 |
+
if (lhs.impl_ == rhs.impl_) {
|
| 315 |
+
return true;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
// Otherwise, just compare values directly.
|
| 319 |
+
return *lhs.impl_ == *rhs.impl_;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
template<class T>
|
| 323 |
+
bool operator!=(const List<T>& lhs, const List<T>& rhs) {
|
| 324 |
+
return !(lhs == rhs);
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
template<class T>
|
| 328 |
+
bool List<T>::is(const List<T>& rhs) const {
|
| 329 |
+
return this->impl_ == rhs.impl_;
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
template<class T>
|
| 333 |
+
std::vector<T> List<T>::vec() const {
|
| 334 |
+
std::vector<T> result(begin(), end());
|
| 335 |
+
return result;
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
template<class T>
|
| 339 |
+
size_t List<T>::use_count() const {
|
| 340 |
+
return impl_.use_count();
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
template <class T>
|
| 344 |
+
TypePtr List<T>::elementType() const {
|
| 345 |
+
return impl_->elementType;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
template <class T>
|
| 349 |
+
void List<T>::unsafeSetElementType(TypePtr t) {
|
| 350 |
+
impl_->elementType = std::move(t);
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
}
|
phivenv/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
|
| 5 |
+
// define constants like M_PI and C keywords for MSVC
|
| 6 |
+
#ifdef _MSC_VER
|
| 7 |
+
#ifndef _USE_MATH_DEFINES
|
| 8 |
+
#define _USE_MATH_DEFINES
|
| 9 |
+
#endif
|
| 10 |
+
#include <math.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#include <array>
|
| 14 |
+
#include <cmath>
|
| 15 |
+
#include <cstdint>
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
|
| 19 |
+
constexpr int MERSENNE_STATE_N = 624;
|
| 20 |
+
constexpr int MERSENNE_STATE_M = 397;
|
| 21 |
+
constexpr uint32_t MATRIX_A = 0x9908b0df;
|
| 22 |
+
constexpr uint32_t UMASK = 0x80000000;
|
| 23 |
+
constexpr uint32_t LMASK = 0x7fffffff;
|
| 24 |
+
|
| 25 |
+
/**
|
| 26 |
+
* Note [Mt19937 Engine implementation]
|
| 27 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 28 |
+
* Originally implemented in:
|
| 29 |
+
* http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c
|
| 30 |
+
* and modified with C++ constructs. Moreover the state array of the engine
|
| 31 |
+
* has been modified to hold 32 bit uints instead of 64 bits.
|
| 32 |
+
*
|
| 33 |
+
* Note that we reimplemented mt19937 instead of using std::mt19937 because,
|
| 34 |
+
* at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2
|
| 35 |
+
* by default and following are the benchmark numbers (benchmark code can be found at
|
| 36 |
+
* https://github.com/syed-ahmed/benchmark-rngs):
|
| 37 |
+
*
|
| 38 |
+
* with -O2
|
| 39 |
+
* Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s
|
| 40 |
+
* Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s
|
| 41 |
+
 * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s
 * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s
 *
 * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution,
 * however we can't use std::uniform_real_distribution because of this bug:
 * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used
 * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm
 * than what's in pytorch currently and that messes up the tests in tests_distributions.py.
 * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower
 * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter.
 *
 * Copyright notice:
 * A C-program for MT19937, with initialization improved 2002/2/10.
 * Coded by Takuji Nishimura and Makoto Matsumoto.
 * This is a faster version by taking Shawn Cokus's optimization,
 * Matthe Bellew's simplification, Isaku Wada's real version.
 *
 * Before using, initialize the state by using init_genrand(seed)
 * or init_by_array(init_key, key_length).
 *
 * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * 3. The names of its contributors may not be used to endorse or promote
 *    products derived from this software without specific prior written
 *    permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 *
 * Any feedback is very welcome.
 * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html
 * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space)
 */

/**
 * mt19937_data_pod is used to get POD data in and out
 * of mt19937_engine. Used in torch.get_rng_state and
 * torch.set_rng_state functions.
 */
struct mt19937_data_pod {
  uint64_t seed_;
  int left_;
  bool seeded_;
  uint32_t next_;
  std::array<uint32_t, MERSENNE_STATE_N> state_;
};

class mt19937_engine {
public:

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  inline explicit mt19937_engine(uint64_t seed = 5489) {
    init_with_uint32(seed);
  }

  inline mt19937_data_pod data() const {
    return data_;
  }

  inline void set_data(const mt19937_data_pod& data) {
    data_ = data;
  }

  inline uint64_t seed() const {
    return data_.seed_;
  }

  inline bool is_valid() {
    if ((data_.seeded_ == true)
        && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N)
        && (data_.next_ <= MERSENNE_STATE_N)) {
      return true;
    }
    return false;
  }

  inline uint32_t operator()() {
    if (--(data_.left_) == 0) {
      next_state();
    }
    uint32_t y = *(data_.state_.data() + data_.next_++);
    y ^= (y >> 11);
    y ^= (y << 7) & 0x9d2c5680;
    y ^= (y << 15) & 0xefc60000;
    y ^= (y >> 18);

    return y;
  }

private:
  mt19937_data_pod data_;

  inline void init_with_uint32(uint64_t seed) {
    data_.seed_ = seed;
    data_.seeded_ = true;
    data_.state_[0] = seed & 0xffffffff;
    for (const auto j : c10::irange(1, MERSENNE_STATE_N)) {
      data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j);
    }
    data_.left_ = 1;
    data_.next_ = 0;
  }

  inline uint32_t mix_bits(uint32_t u, uint32_t v) {
    return (u & UMASK) | (v & LMASK);
  }

  inline uint32_t twist(uint32_t u, uint32_t v) {
    return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0);
  }

  inline void next_state() {
    uint32_t* p = data_.state_.data();
    data_.left_ = MERSENNE_STATE_N;
    data_.next_ = 0;

    for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) {
      *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]);
    }

    for(int j = MERSENNE_STATE_M; --j; p++) {
      *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]);
    }

    *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]);
  }

};

typedef mt19937_engine mt19937;

} // namespace at

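A minimal usage sketch of the engine declared above, assuming the header is included as <ATen/core/MT19937RNGEngine.h> and the translation unit is built against the ATen headers; the seed value and variable names are illustrative only:

#include <ATen/core/MT19937RNGEngine.h>
#include <cstdint>
#include <iostream>

int main() {
  // Seed the generator; 5489 is the class default.
  at::mt19937 gen(42);

  // Draw a few raw 32-bit values, as at::uniform_real_distribution would.
  for (int i = 0; i < 4; ++i) {
    std::cout << gen() << "\n";
  }

  // Snapshot the POD state (what torch.get_rng_state serializes) ...
  at::mt19937_data_pod snapshot = gen.data();
  uint32_t next_after_snapshot = gen();

  // ... and restore it later to replay the same stream (torch.set_rng_state).
  gen.set_data(snapshot);
  std::cout << (gen() == next_after_snapshot) << "\n"; // prints 1
  return 0;
}
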
phivenv/Lib/site-packages/torch/include/ATen/core/NamedTensor.h
ADDED
@@ -0,0 +1,143 @@
#pragma once

#include <ATen/core/Dimname.h>
#include <c10/core/TensorImpl.h>

namespace at {

class TensorBase;

// XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
// Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
// so we have a couple of workarounds.
//
// In the long term, we'll move Dimname to c10 and everything in this file
// can be refactored out. The main blocker for that is that "c10::Symbol"
// actually exists outside of c10 and needs to be moved in.

// TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
// XXX: Ideally we would just put std::optional<vector<Dimname>> into TensorImpl.
//
// This class has an important invariant: there must be at least ONE
// non-wildcard
struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
  // This enum is to remind people that the invariant on constructors is that
  // the list of dimnames must have at least one non-wildcard
  enum HAS_NON_WILDCARD {
    HasNonWildcard
  };

  explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
    : names_(names.vec()) {
    check_invariants();
  }
  explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names)
    : names_(std::move(names)) {
    check_invariants();
  }

  std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
    return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_);
  }

  DimnameList names() const { return names_; }

  // Used for an assertion in TensorImpl.h
  int64_t slow_dim() const override {
    return static_cast<int64_t>(names_.size());
  }

  void check_invariants() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
  }

  void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    std::copy(new_names.begin(), new_names.end(), names_.begin());
    check_invariants();
  }

  void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    names_ = std::move(new_names);
    check_invariants();
  }

  // INVARIANT: at least one Dimname is non-WILDCARD
  std::vector<Dimname> names_;
};

// When NamesMode is disabled, then all operations ignore tensors' names fields.
// Concretely speaking, all tensors are treated as having nullopt names.
struct TORCH_API NamesMode {
  static bool is_enabled();
  static void set_enabled(bool enabled);
};


// A RAII, thread local (!) guard that enables or disables names upon
// construction, and sets it back to the original value upon destruction.
struct TORCH_API NoNamesGuard {
  NoNamesGuard() : prev_mode(NamesMode::is_enabled()) {
    NamesMode::set_enabled(false);
  }
  NoNamesGuard(const NoNamesGuard&) = delete;
  NoNamesGuard(NoNamesGuard&&) = delete;
  NoNamesGuard& operator=(const NoNamesGuard&) = delete;
  NoNamesGuard& operator=(NoNamesGuard&&) = delete;
  ~NoNamesGuard() {
    if (initialized) {
      reset();
    }
  }
  void reset() {
    TORCH_INTERNAL_ASSERT(initialized);
    NamesMode::set_enabled(prev_mode);
  }
 private:
  bool prev_mode;
  bool initialized{true};
};

void check_names_valid_for(const TensorBase& tensor, DimnameList names);
void check_names_valid_for(size_t tensor_dim, DimnameList names);

// Sets the names of `tensor` to be `names`.
TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names);
TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names);

constexpr size_t kMaxNamedTensorDim = 64;

DimnameList default_names(size_t len);

namespace impl {

// Some helper functions on TensorImpl. Useful for working with names in TH.
// XXX: Ideally these would exist as methods on TensorImpl
TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names);
TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names);

void check_names_valid_for(TensorImpl* impl, DimnameList names);

// Returns true if the tensor's names exist and are not all 'None'.
// Returns false if the tensor's names don't exist (were not allocated),
// or if all names are 'None'.
// We treat not-allocated-names the same as allocated names that are all 'None'.
TORCH_API bool has_names(const TensorImpl* impl);

// Returns the names of the tensor's dimensions.
// Unnamed tensors are treated as having 'None' in all dimensions; this method
// would return a DimnameList of all 'None's for an unnamed tensor.
TORCH_API DimnameList get_names(const TensorImpl* impl);

// This is more of an implementation detail; one should use impl::get_names /
// Tensor::names() whenever possible because it provides a cleaner API.
// Returns the names of the tensor if they have been allocated; returns nullopt
// instead if they haven't been. The names of a tensor are not allocated if a
// tensor is constructed with names=None.
TORCH_API std::optional<DimnameList> get_opt_names(const TensorImpl* impl);

} // namespace impl

} // namespace at

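A short sketch of the RAII guard declared above, assuming the program links against libtorch (NamesMode::is_enabled and set_enabled are defined out of line in the library); the function name is illustrative:

#include <ATen/core/NamedTensor.h>
#include <iostream>

// Temporarily run a region of code with named-tensor propagation disabled.
void run_without_names() {
  std::cout << "names enabled? " << at::NamesMode::is_enabled() << "\n";
  {
    at::NoNamesGuard guard;  // disables names mode for this scope
    std::cout << "inside guard: " << at::NamesMode::is_enabled() << "\n"; // 0
  }                          // destructor restores the previous mode
  std::cout << "after guard: " << at::NamesMode::is_enabled() << "\n";
}
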
phivenv/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h
ADDED
@@ -0,0 +1,187 @@
#pragma once

#include <c10/core/ConstantSymNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <optional>
#include <string>

namespace c10 {

// The motivating usecase for this is to represent the ragged size structure
// of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
// allows us to simply return [B, j0, D] if someone queries for the size of our
// tensor.
//
// Morally we define comparison between two nested ints to return true if
// that comparison holds for all corresponding elements of the arrays they
// represent. Comparison between a nested int and a plain int is defined
// similarly.
//
// To simulate this desired behavior but also avoid the O(N) cost of checking,
// we associate each raggedness pattern with an integer "id" that can be used as
// a proxy to evaluate equality. We also constrain the range of values for this
// as to enable inequality checks.
//
// We also support a positive integer scalar "coeff" that is used for computing
// strides. For example given, a [B, j0, D] tensor, it can be strided in two
// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
// differentiate the two cases.
//
// During tracing the strides of the outputs need to be a function of the size
// and strides of the inputs so it is important that NestedIntSymNode itself is
// able to express this.
class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
 public:
  // CAUTION: you should probably not be constructing these directly; please
  // use the higher-level API in python instead (TODO: actually introduce that).
  explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
      : val_(val), coeff_(coeff) {}

  bool bool_() override {
    return false;
  }

  bool is_int() override {
    return true;
  }

  bool is_float() override {
    return false;
  }

  bool is_bool() override {
    return false;
  }

  bool is_nested_int() const override {
    return true;
  }

  bool has_hint() override {
    return true;
  }

  c10::SymNode wrap_int(int64_t num) override {
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
  }

  int64_t guard_int(const char* file, int64_t line) override {
    TORCH_CHECK(false);
  }

  double guard_float(const char* file, int64_t line) override {
    TORCH_CHECK(false, "not a float");
  }

  bool guard_bool(const char* file, int64_t line) override {
    TORCH_CHECK(false, "not a bool");
  }

  int64_t int_() override {
    TORCH_CHECK(false);
  }

  std::string str() override {
    if (coeff_ == 1) {
      return "j" + std::to_string(val_);
    }
    return std::to_string(coeff_) + "*j" + std::to_string(val_);
  }

  // NOTE [ Inequalities with nested int ]
  //
  // The semantics of nested int when it comes to relations is that it is
  // treated as integer known to be within a certain range,
  //
  //     j0 \in [2, int64_t::max]
  //
  // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
  // This is a useful default range for the raggedness pattern of a jagged
  // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
  // specialization checks.
  //
  // [ Indeterminate inequalities error out ]
  //
  // Given the semantic defined above, certain relations like j0 < 3 are thus
  // indeterminable. In our impl today, evaluating such relations errors out.
  //
  // It may seem convenient to just define indeterminate relations to return
  // False, but the implementation we maintain in parallel using sympy does not
  // allow this.
  //
  // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
  // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
  // would mean that if we define the indeterminate j0 >= 3 to be
  // False, the equally indeterminate j0 < 3 will be evaluated to be True!
  //
  // [ Coefficients are assumed positive ]
  //
  // For the purpose of computing inequalities, we consider the coefficient of
  // the nested int to be a positive integer.
  //
  // Thus, no modifications are needed to the logic since
  // j0 >= k implies coeff * j0 >= k
  //
  c10::SymNode eq(const c10::SymNode& other) override;
  c10::SymNode ne(const c10::SymNode& other) override;
  c10::SymNode ge(const c10::SymNode& other) override;
  c10::SymNode gt(const c10::SymNode& other) override;
  c10::SymNode lt(const c10::SymNode& other) override;
  c10::SymNode le(const c10::SymNode& other) override;
  c10::SymNode mul(const c10::SymNode& other) override;

  std::optional<int64_t> nested_int() override {
    return val_;
  }

  std::optional<int64_t> nested_int_coeff() override {
    return coeff_;
  }

  bool is_symbolic() override {
    return false;
  }

  c10::SymNode clone() override;

#define DEFINE_BINARY_NOT_SUPPORTED(name)                            \
  c10::SymNode name(const c10::SymNode& other) override {            \
    TORCH_CHECK(false, #name " not supported by NestedIntSymNode");  \
  }

  DEFINE_BINARY_NOT_SUPPORTED(add)
  DEFINE_BINARY_NOT_SUPPORTED(sub)
  DEFINE_BINARY_NOT_SUPPORTED(truediv)
  DEFINE_BINARY_NOT_SUPPORTED(pow)
  DEFINE_BINARY_NOT_SUPPORTED(floordiv)
  DEFINE_BINARY_NOT_SUPPORTED(mod)
  DEFINE_BINARY_NOT_SUPPORTED(sym_min)
  DEFINE_BINARY_NOT_SUPPORTED(sym_max)
  DEFINE_BINARY_NOT_SUPPORTED(sym_and)
  DEFINE_BINARY_NOT_SUPPORTED(sym_or)

#undef DEFINE_BINARY_NOT_SUPPORTED

#define DEFINE_NOT_SUPPORTED(name)                                      \
  c10::SymNode name() override {                                        \
    TORCH_CHECK(false, #name " is not supported by NestedIntSymNode");  \
  }

  DEFINE_NOT_SUPPORTED(sym_not)
  DEFINE_NOT_SUPPORTED(ceil)
  DEFINE_NOT_SUPPORTED(floor)
  DEFINE_NOT_SUPPORTED(neg)
  DEFINE_NOT_SUPPORTED(sym_float)

#undef DEFINE_NOT_SUPPORTED

 private:
  int64_t val_;
  int64_t coeff_;
};

} // namespace c10

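A test-style sketch of the node above, only exercising the inline accessors (str, nested_int, nested_int_coeff), so it does not depend on the out-of-line relational operators; the header itself cautions against constructing these directly in real code:

#include <ATen/core/NestedIntSymNodeImpl.h>
#include <iostream>

int main() {
  // A nested int "j0" with coefficient 1, standing in for a ragged dimension.
  auto j0 = c10::make_intrusive<c10::NestedIntSymNodeImpl>(/*val=*/0, /*coeff=*/1);

  std::cout << j0->str() << "\n";                      // "j0"
  std::cout << j0->is_nested_int() << "\n";            // 1
  std::cout << j0->nested_int().value() << "\n";       // 0
  std::cout << j0->nested_int_coeff().value() << "\n"; // 1
  return 0;
}
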
phivenv/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h
ADDED
@@ -0,0 +1,240 @@
#pragma once

// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
#define _USE_MATH_DEFINES
#include <math.h>
#endif


#ifdef __CUDACC__
#include <cuda.h>
#endif

#include <array>
#include <c10/macros/Macros.h>
#include <cmath>
#include <cstdint>

namespace at {

// typedefs for holding vector data
namespace detail {

typedef std::array<uint32_t, 4> UINT4;
typedef std::array<uint32_t, 2> UINT2;
typedef std::array<double, 2> DOUBLE2;
typedef std::array<float, 2> FLOAT2;

} // namespace detail

/**
 * Note [Philox Engine implementation]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Originally implemented in PyTorch's fusion compiler
 * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
 * for details regarding the engine.
 *
 * Note that currently this implementation of the philox engine is not used
 * anywhere except for tests in cpu_generator_test.cpp. However, this engine
 * will replace curandStatePhilox4_32_10_t in the future.
 *
 * The philox engine takes a seed value, a subsequence
 * for starting the generation and an offset for the subsequence.
 * Think of this engine as an algorithm producing a huge array. We are
 * parallelizing this array by partitioning the huge array and assigning
 * a thread index to each partition. In other words, each seed value
 * (there are 2^64 possible seed values) gives a sub array of size
 * 2^128 (each element in that array is a 128 bit number). Reasoning
 * behind the array being of size 2^128 is, there are 2^64 possible
 * thread index value and there is an array of size 2^64 for each of
 * those thread index. Hence 2^64 * 2^64 = 2^128 for each seed value.
 *
 * In short, this generator can produce 2^64 (seed values) * 2^128 (number
 * of elements in an array given by a seed value) = 2^192 values.
 *
 * Arguments:
 * seed: Seed values could be any number from 0 to 2^64-1.
 * subsequence: Subsequence is just the cuda thread indexing with:
 *              - blockIdx.x * blockDim.x + threadIdx.x
 * offset: The offset variable in PhiloxEngine decides how many 128-bit
 *         random numbers to skip (i.e. how many groups of 4, 32-bit numbers to skip)
 *         and hence really decides the total number of randoms that can be achieved
 *         for the given subsequence.
 */

class philox_engine {
public:

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721,
                                                uint64_t subsequence = 0,
                                                uint64_t offset = 0) {

    reset_state(seed, subsequence);
    incr_n(offset);
  }

  C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721,
                                          uint64_t subsequence = 0) {
    key_[0] = static_cast<uint32_t>(seed);
    key_[1] = static_cast<uint32_t>(seed >> 32);
    counter_ = detail::UINT4{};
    counter_[2] = static_cast<uint32_t>(subsequence);
    counter_[3] = static_cast<uint32_t>(subsequence >> 32);
    STATE = 0;
  }

  /**
   * Set the offset field of Philox Generator to the desired offset.
   */
  C10_HOST_DEVICE inline void set_offset(uint64_t offset) {
    counter_[0] = static_cast<uint32_t>(offset);
    counter_[1] = static_cast<uint32_t>(offset >> 32);
  }

  /**
   * Gets the current offset of the Philox Generator.
   */
  C10_HOST_DEVICE uint64_t get_offset() const {
    uint64_t lo = static_cast<uint64_t>(counter_[0]);
    uint64_t hi = static_cast<uint64_t>(counter_[1]) << 32;
    return lo | hi;
  }

  /**
   * Produces a unique 32-bit pseudo random number on every invocation. Bookkeeps state to avoid waste.
   */
  C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior
    if(STATE == 0) {
      detail::UINT4 counter = counter_;
      detail::UINT2 key = key_;
      output_ = rand(counter, key, n_rounds);
      incr();
    }
    uint32_t ret = output_[static_cast<int>(STATE)];
    STATE = (STATE + 1) & 3;
    return ret;
  }

  inline float randn(uint32_t n_rounds) {
#ifdef __CUDA_ARCH__
    AT_ASSERT(false, "Unsupported invocation of randn on CUDA");
#endif
    if(STATE == 0) {
      detail::UINT4 counter = counter_;
      detail::UINT2 key = key_;
      output_ = rand(counter, key, n_rounds);
      incr();
    }
    // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method
    // TODO(voz) We use std:: below, and thus need a separate impl for CUDA.
    float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log.
    float u2 = 1 - uint32_to_uniform_float(output_[1]);
    return static_cast<float>(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2));
  }

  /**
   * Function that Skips N 128 bit numbers in a subsequence
   */
  C10_HOST_DEVICE inline void incr_n(uint64_t n) {
    uint32_t nlo = static_cast<uint32_t>(n);
    uint32_t nhi = static_cast<uint32_t>(n >> 32);
    counter_[0] += nlo;
    // if overflow in x has occurred, carry over to nhi
    if (counter_[0] < nlo) {
      nhi++;
      // if overflow in nhi has occurred during carry over,
      // propagate that overflow to y and exit to increment z
      // otherwise return
      counter_[1] += nhi;
      if(nhi != 0) {
        if (nhi <= counter_[1]) {
          return;
        }
      }
    } else {
      // if overflow in y has occurred during addition,
      // exit to increment z
      // otherwise return
      counter_[1] += nhi;
      if (nhi <= counter_[1]) {
        return;
      }
    }
    if (++counter_[2])
      return;
    ++counter_[3];
  }

  /**
   * Function that Skips one 128 bit number in a subsequence
   */
  C10_HOST_DEVICE inline void incr() {
    if (++counter_[0])
      return;
    if (++counter_[1])
      return;
    if (++counter_[2]) {
      return;
    }
    ++counter_[3];
  }

private:
  detail::UINT4 counter_;
  detail::UINT4 output_;
  detail::UINT2 key_;
  uint32_t STATE;

  C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b,
                                            uint32_t *result_high) {
#ifdef __CUDA_ARCH__
    *result_high = __umulhi(a, b);
    return a*b;
#else
    const uint64_t product = static_cast<uint64_t>(a) * b;
    *result_high = static_cast<uint32_t>(product >> 32);
    return static_cast<uint32_t>(product);
#endif
  }

  C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) {
    uint32_t hi0 = 0;
    uint32_t hi1 = 0;
    uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0);
    uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1);
    detail::UINT4 ret;
    ret[0] = hi1 ^ ctr[1] ^ in_key[0];
    ret[1] = lo1;
    ret[2] = hi0 ^ ctr[3] ^ in_key[1];
    ret[3] = lo0;
    return ret;
  }

  C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) {
    // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
    constexpr float scale = 4.6566127342e-10;
    return static_cast<float>(value & 0x7FFFFFFF) * scale;
  }


  C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) {
    for (uint32_t round = 0; round < (n_rounds - 1); round++) {
      counter = single_round(counter, key);
      key[0] += (kPhilox10A); key[1] += (kPhilox10B);
    }
    return single_round(counter, key);
  }


  static const uint32_t kPhilox10A = 0x9E3779B9;
  static const uint32_t kPhilox10B = 0xBB67AE85;
  static const uint32_t kPhiloxSA = 0xD2511F53;
  static const uint32_t kPhiloxSB = 0xCD9E8D57;
};

typedef philox_engine Philox4_32;

} // namespace at

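A host-side usage sketch of the counter-based engine above, assuming it is built against the ATen headers; the seed/subsequence/offset values are illustrative. One call to operator() yields one 32-bit value, and each 128-bit Philox block supplies four of them, so an offset of 1 should skip exactly four draws:

#include <ATen/core/PhiloxRNGEngine.h>
#include <iostream>

int main() {
  // Same seed + subsequence + offset reproduces the same stream.
  at::Philox4_32 eng(/*seed=*/0, /*subsequence=*/0, /*offset=*/0);

  for (int i = 0; i < 4; ++i) {
    std::cout << eng() << "\n";   // draws from the first 128-bit block
  }

  // Skip ahead: start a second engine one 128-bit block (4 draws) later.
  at::Philox4_32 skipped(/*seed=*/0, /*subsequence=*/0, /*offset=*/1);
  std::cout << (skipped() == eng()) << "\n"; // expected to print 1
  return 0;
}
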
phivenv/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h
ADDED
@@ -0,0 +1,35 @@
#pragma once
#include <ATen/core/TorchDispatchUtils.h>


namespace at::impl {

struct TORCH_API RestorePythonTLSSnapshot {
  RestorePythonTLSSnapshot();
  RestorePythonTLSSnapshot(RestorePythonTLSSnapshot&& other) = delete;
  RestorePythonTLSSnapshot(const RestorePythonTLSSnapshot&) = delete;
  RestorePythonTLSSnapshot& operator=(const RestorePythonTLSSnapshot&) = delete;
  RestorePythonTLSSnapshot& operator=(RestorePythonTLSSnapshot&&) = delete;
  ~RestorePythonTLSSnapshot();

private:
  c10::impl::LocalDispatchKeySet saved_;
  c10::impl::ForceDispatchKeyGuard guard_;
};


// RAII guard to make working with the above TLS safer.
struct TORCH_API MaybeSetTLSOnEntryGuard {
public:
  MaybeSetTLSOnEntryGuard();
  MaybeSetTLSOnEntryGuard(MaybeSetTLSOnEntryGuard&& other) = delete;
  MaybeSetTLSOnEntryGuard(const MaybeSetTLSOnEntryGuard&) = delete;
  MaybeSetTLSOnEntryGuard& operator=(const MaybeSetTLSOnEntryGuard&) = delete;
  MaybeSetTLSOnEntryGuard& operator=(MaybeSetTLSOnEntryGuard&&) = delete;
  ~MaybeSetTLSOnEntryGuard();

private:
  bool value_set_;
};

} // namespace at::impl

phivenv/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h
ADDED
@@ -0,0 +1,22 @@
#pragma once

#include <ATen/core/dispatch/Dispatcher.h>

// TODO: this can probably live in c10


namespace at::impl {

class TORCH_API PythonOpRegistrationTrampoline final {
  static std::atomic<c10::impl::PyInterpreter*> interpreter_;

public:
  // Returns true if you successfully registered yourself (that means
  // you are in the hot seat for doing the operator registrations!)
  static bool registerInterpreter(c10::impl::PyInterpreter*);

  // Returns nullptr if no interpreter has been registered yet.
  static c10::impl::PyInterpreter* getInterpreter();
};

} // namespace at::impl

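An illustrative sketch of how the trampoline above can be queried; the helper function name is hypothetical, and the getInterpreter/registerInterpreter definitions live in the library rather than in this header:

#include <ATen/core/PythonOpRegistrationTrampoline.h>

// Code that wants to talk to the Python dispatcher can first check whether
// a Python interpreter has registered itself yet.
bool python_interpreter_available() {
  return at::impl::PythonOpRegistrationTrampoline::getInterpreter() != nullptr;
}
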
phivenv/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h
ADDED
@@ -0,0 +1,84 @@
#pragma once

#include <c10/core/ScalarType.h>
#include <c10/core/QScheme.h>
#include <c10/util/intrusive_ptr.h>

namespace at {

class Tensor;
struct QTensorImpl;
struct Quantizer;
using ConstQuantizerPtr = const c10::intrusive_ptr<Quantizer>&;
using QuantizerPtr = c10::intrusive_ptr<Quantizer>;

/**
 * Quantizer is the class for storing all the information
 * that's necessary to perform quantize and dequantize
 * operation.
 *
 * We might have different types of quantization schemes and this is
 * the base class for all quantizers.
 *
 * QTensorImpl will hold a pointer to Quantizer so that we can support
 * different quantization schemes on Tensor.
 *
 * For example, the most common quantization scheme, Affine Quantization,
 * requires scale and zero_point as parameters, we'll store scale and zero_point
 * inside the instance and we can use it to quantize a float Tensor or
 * dequantize a quantized Tensor.
 *
 * When you add new types of leaf Quantizer class, please also
 * make sure to add a corresponding QScheme enum since
 * they should have one to one mapping.
 *
 * Note about intrusive_ptr:
 * Quantized Tensor holds an intrusive_ptr to Quantizer, and multiple Tensor can
 * share the same Quantizer. Quantizer should be immutable.
 */
struct TORCH_API Quantizer : public c10::intrusive_ptr_target {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const ScalarType scalar_type_;
  explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {}
  ~Quantizer() override = default;

  // Copied from torch/csrc/jit/ir/scope.h
  QuantizerPtr intrusive_from_this() {
    c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
                                           // from a raw `this` pointer
                                           // so we need to bump the refcount
                                           // to account for this ownership
    return c10::intrusive_ptr<Quantizer>::reclaim(this);
  }

  /**
   * Each concrete Quantizer type should have a unique QScheme type.
   */
  virtual QScheme qscheme() const = 0;

  ScalarType scalar_type() const {
    return scalar_type_;
  }

  /**
   * quantize a float Tensor into a quantized Tensor.
   */
  virtual Tensor quantize(const Tensor& t) = 0;

  /**
   * dequantize a quantized Tensor into a float Tensor.
   */
  virtual Tensor dequantize(const Tensor& t) = 0;

  /**
   * dequantize a quantized Tensor into a float Tensor, out= variant
   */
  virtual Tensor& dequantize_out(Tensor& out, const Tensor& t) = 0;

  /**
   * Compare against `other` for equality.
   */
  virtual bool equalTo(QuantizerPtr other) const = 0;
};

} // namespace at

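Quantizer is abstract, so a concrete scheme must implement qscheme, quantize, dequantize, dequantize_out, and equalTo. The skeleton below is a hypothetical per-tensor affine quantizer, shown only to illustrate the shape of the interface; it is not the real ATen implementation (those live in the quantized backend), and the rounding/clamping rule here is a simplification:

#include <ATen/ATen.h>
#include <ATen/core/QuantizerBase.h>

struct MyAffineQuantizer : public at::Quantizer {
  double scale_;
  int64_t zero_point_;

  MyAffineQuantizer(at::ScalarType t, double scale, int64_t zero_point)
      : at::Quantizer(t), scale_(scale), zero_point_(zero_point) {}

  at::QScheme qscheme() const override {
    return at::kPerTensorAffine;
  }

  at::Tensor quantize(const at::Tensor& t) override {
    // Real code would allocate a quantized tensor that shares this quantizer;
    // here we only sketch the affine rounding rule for an 8-bit range.
    return at::clamp(at::round(t / scale_) + zero_point_, 0, 255);
  }

  at::Tensor dequantize(const at::Tensor& t) override {
    return (t.to(at::kFloat) - zero_point_) * scale_;
  }

  at::Tensor& dequantize_out(at::Tensor& out, const at::Tensor& t) override {
    out = dequantize(t);
    return out;
  }

  bool equalTo(at::QuantizerPtr other) const override {
    auto* o = dynamic_cast<MyAffineQuantizer*>(other.get());
    return o && o->scalar_type() == scalar_type() &&
           o->scale_ == scale_ && o->zero_point_ == zero_point_;
  }
};
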
phivenv/Lib/site-packages/torch/include/ATen/core/Range.h
ADDED
@@ -0,0 +1,25 @@
#pragma once

#include <cstdint>
#include <iosfwd>

namespace at {

struct Range {
  Range(int64_t begin, int64_t end)
    : begin(begin)
    , end(end) {}

  int64_t size() const { return end - begin; }

  Range operator/(int64_t divisor) {
    return Range(begin / divisor, end / divisor);
  }

  int64_t begin;
  int64_t end;
};

std::ostream& operator<<(std::ostream& out, const Range& range);

} // namespace at

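A small sketch of the struct above, which behaves like a half-open [begin, end) work range (size() is end - begin); only the public members are used, so it works with the header alone:

#include <ATen/core/Range.h>
#include <iostream>

int main() {
  at::Range chunk(128, 256);          // work range [128, 256)
  std::cout << chunk.size() << "\n";  // 128

  at::Range coarse = chunk / 2;       // scales both endpoints: [64, 128)
  std::cout << coarse.begin << " " << coarse.end << "\n";
  return 0;
}
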
phivenv/Lib/site-packages/torch/include/ATen/core/Reduction.h
ADDED
@@ -0,0 +1,14 @@
#pragma once

namespace at::Reduction {

// NB: Keep this in sync with Reduction class in torch/nn/_reduction.py
// These constants control the reduction behavior of loss functions.
// Ideally, this would be a scoped enum, but jit doesn't support that
enum Reduction {
  None, // Do not reduce
  Mean, // (Possibly weighted) mean of losses
  Sum,  // Sum losses
  END
};
} // namespace at::Reduction

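An illustrative sketch of how these constants are typically consumed: loss kernels receive the reduction as a plain int64_t and compare it against the enum values. The function name and signature are hypothetical:

#include <ATen/core/Reduction.h>
#include <cstdint>

double apply_reduction(double summed_loss, int64_t numel, int64_t reduction) {
  if (reduction == at::Reduction::Mean) {
    return summed_loss / static_cast<double>(numel);
  }
  if (reduction == at::Reduction::Sum) {
    return summed_loss;
  }
  // at::Reduction::None would keep the per-element losses instead of a scalar.
  return summed_loss;
}
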
phivenv/Lib/site-packages/torch/include/ATen/core/Scalar.h
ADDED
@@ -0,0 +1 @@
#include <c10/core/Scalar.h>

phivenv/Lib/site-packages/torch/include/ATen/core/ScalarType.h
ADDED
@@ -0,0 +1 @@
#include <c10/core/ScalarType.h>

phivenv/Lib/site-packages/torch/include/ATen/core/Tensor.h
ADDED
@@ -0,0 +1,98 @@
#pragma once

#include <ATen/core/TensorBody.h>
#include <c10/util/Exception.h>

namespace at {
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
class TORCH_API OptionalTensorRef {
 public:
  OptionalTensorRef() = default;

  ~OptionalTensorRef() {
    ref_.unsafeReleaseTensorImpl();
  }

  OptionalTensorRef(const TensorBase& src)
      : ref_(Tensor::unsafe_borrow_t{}, src) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
  }

  OptionalTensorRef(const OptionalTensorRef& rhs)
      : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}

  OptionalTensorRef(OptionalTensorRef&& rhs) = default;
  OptionalTensorRef& operator=(OptionalTensorRef rhs) {
    std::swap(ref_, rhs.ref_);
    return *this;
  }

  bool has_value() const {
    return ref_.defined();
  }

  const Tensor& getTensorRef() const & {
    return ref_;
  }

  const Tensor& operator*() const & {
    return ref_;
  }

  const Tensor* operator->() const & {
    return &ref_;
  }

  operator bool() const {
    return ref_.defined();
  }

 private:
  Tensor ref_;
};

// Use to convert a TensorBase (that may be undefined) to an at::Tensor
// without bumping refcount.
class TORCH_API TensorRef {
 public:
  ~TensorRef() {
    ref_.unsafeReleaseTensorImpl();
  }

  TensorRef(const TensorBase& src)
      : ref_(Tensor::unsafe_borrow_t{}, src) {}
  TensorRef(TensorRef&& other) = default;
  TensorRef(const TensorRef&) = default;
  TensorRef& operator=(const TensorRef&) = default;
  TensorRef& operator=(TensorRef&&) = default;

  const Tensor& operator*() const & {
    return ref_;
  }
 private:
  Tensor ref_;
};

template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
  // Return the grad argument in case of a hook with void return type to have an
  // std::function with Tensor return type
  static_assert(std::is_same_v<decltype(hook(Tensor())), void>,
                "Expected hook to return void");
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
    TensorRef grad(grad_base);
    fn(*grad);
    return Tensor();
  });
}

template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
    TensorRef grad(grad_base);
    Tensor ret = fn(*grad);
    return TensorBase(std::move(ret));
  });
}

} // namespace at

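A sketch of the two register_hook flavours defined above, assuming the tensor passed in requires grad and that the program runs with autograd available (hooks only fire during a later backward pass); the function name is illustrative:

#include <ATen/ATen.h>
#include <iostream>

void hook_demo(const at::Tensor& t_requiring_grad) {
  // Void-returning hook: the original gradient is passed through unchanged.
  t_requiring_grad.register_hook([](const at::Tensor& grad) {
    std::cout << "grad norm: " << grad.norm().item<double>() << "\n";
  });

  // Tensor-returning hook: the returned tensor replaces the gradient.
  t_requiring_grad.register_hook([](const at::Tensor& grad) {
    return grad * 2;
  });
}
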
phivenv/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h
ADDED
@@ -0,0 +1,275 @@
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Macros.h>
|
| 4 |
+
#include <c10/util/ArrayRef.h>
|
| 5 |
+
#include <c10/util/Deprecated.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
#include <cstddef>
|
| 9 |
+
#include <cstdint>
|
| 10 |
+
#include <type_traits>
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
|
| 15 |
+
// is used to enable the __restrict__ keyword/modifier for the data
|
| 16 |
+
// passed to cuda.
|
| 17 |
+
template <typename T>
|
| 18 |
+
struct DefaultPtrTraits {
|
| 19 |
+
typedef T* PtrType;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 23 |
+
template <typename T>
|
| 24 |
+
struct RestrictPtrTraits {
|
| 25 |
+
typedef T* __restrict__ PtrType;
|
| 26 |
+
};
|
| 27 |
+
#endif
|
| 28 |
+
|
| 29 |
+
// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
|
| 30 |
+
// For CUDA tensors it is used in device code (only). This means that we restrict ourselves
|
| 31 |
+
// to functions and types available there (e.g. IntArrayRef isn't).
|
| 32 |
+
|
| 33 |
+
// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
|
| 34 |
+
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
| 35 |
+
class TensorAccessorBase {
|
| 36 |
+
public:
|
| 37 |
+
typedef typename PtrTraits<T>::PtrType PtrType;
|
| 38 |
+
|
| 39 |
+
C10_HOST_DEVICE TensorAccessorBase(
|
| 40 |
+
PtrType data_,
|
| 41 |
+
const index_t* sizes_,
|
| 42 |
+
const index_t* strides_)
|
| 43 |
+
: data_(data_), sizes_(sizes_), strides_(strides_) {}
|
| 44 |
+
C10_HOST IntArrayRef sizes() const {
|
| 45 |
+
return IntArrayRef(sizes_,N);
|
| 46 |
+
}
|
| 47 |
+
C10_HOST IntArrayRef strides() const {
|
| 48 |
+
return IntArrayRef(strides_,N);
|
| 49 |
+
}
|
| 50 |
+
C10_HOST_DEVICE index_t stride(index_t i) const {
|
| 51 |
+
return strides_[i];
|
| 52 |
+
}
|
| 53 |
+
C10_HOST_DEVICE index_t size(index_t i) const {
|
| 54 |
+
return sizes_[i];
|
| 55 |
+
}
|
| 56 |
+
C10_HOST_DEVICE PtrType data() {
|
| 57 |
+
return data_;
|
| 58 |
+
}
|
| 59 |
+
C10_HOST_DEVICE const PtrType data() const {
|
| 60 |
+
return data_;
|
| 61 |
+
}
|
| 62 |
+
protected:
|
| 63 |
+
PtrType data_;
|
| 64 |
+
const index_t* sizes_;
|
| 65 |
+
const index_t* strides_;
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
|
| 69 |
+
// `Tensor.accessor<T, N>()`.
|
| 70 |
+
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
|
| 71 |
+
// indexing on the device uses `TensorAccessor`s.
|
| 72 |
+
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
| 73 |
+
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
|
| 74 |
+
public:
|
| 75 |
+
typedef typename PtrTraits<T>::PtrType PtrType;
|
| 76 |
+
|
| 77 |
+
C10_HOST_DEVICE TensorAccessor(
|
| 78 |
+
PtrType data_,
|
| 79 |
+
const index_t* sizes_,
|
| 80 |
+
const index_t* strides_)
|
| 81 |
+
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
| 82 |
+
|
| 83 |
+
C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
|
| 84 |
+
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
|
| 88 |
+
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
| 89 |
+
}
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
template<typename T, template <typename U> class PtrTraits, typename index_t>
|
| 93 |
+
class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
|
| 94 |
+
public:
|
| 95 |
+
typedef typename PtrTraits<T>::PtrType PtrType;
|
| 96 |
+
|
| 97 |
+
C10_HOST_DEVICE TensorAccessor(
|
| 98 |
+
PtrType data_,
|
| 99 |
+
const index_t* sizes_,
|
| 100 |
+
const index_t* strides_)
|
| 101 |
+
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
| 102 |
+
C10_HOST_DEVICE T & operator[](index_t i) {
|
| 103 |
+
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
|
| 104 |
+
return this->data_[this->strides_[0]*i];
|
| 105 |
+
}
|
| 106 |
+
C10_HOST_DEVICE const T & operator[](index_t i) const {
|
| 107 |
+
return this->data_[this->strides_[0]*i];
|
| 108 |
+
}
|
| 109 |
+
};
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host
|
| 113 |
+
// and as
|
| 114 |
+
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
|
| 115 |
+
// in order to transfer them on the device when calling kernels.
|
| 116 |
+
// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.
|
| 117 |
+
// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
|
| 118 |
+
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
|
| 119 |
+
// on the device, so those functions are host only.
|
| 120 |
+
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
| 121 |
+
class GenericPackedTensorAccessorBase {
|
| 122 |
+
public:
|
| 123 |
+
typedef typename PtrTraits<T>::PtrType PtrType;
|
| 124 |
+
C10_HOST GenericPackedTensorAccessorBase(
|
| 125 |
+
PtrType data_,
|
| 126 |
+
const index_t* sizes_,
|
| 127 |
+
const index_t* strides_)
|
| 128 |
+
: data_(data_) {
|
| 129 |
+
std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
|
| 130 |
+
std::copy(strides_, strides_ + N, std::begin(this->strides_));
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// if index_t is not int64_t, we want to have an int64_t constructor
|
| 134 |
+
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
| 135 |
+
C10_HOST GenericPackedTensorAccessorBase(
|
| 136 |
+
PtrType data_,
|
| 137 |
+
const source_index_t* sizes_,
|
| 138 |
+
const source_index_t* strides_)
|
| 139 |
+
: data_(data_) {
|
| 140 |
+
for (const auto i : c10::irange(N)) {
|
| 141 |
+
this->sizes_[i] = sizes_[i];
|
| 142 |
+
this->strides_[i] = strides_[i];
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
C10_HOST_DEVICE index_t stride(index_t i) const {
|
| 147 |
+
return strides_[i];
|
| 148 |
+
}
|
| 149 |
+
C10_HOST_DEVICE index_t size(index_t i) const {
|
| 150 |
+
return sizes_[i];
|
| 151 |
+
}
|
| 152 |
+
C10_HOST_DEVICE PtrType data() {
|
| 153 |
+
return data_;
|
| 154 |
+
}
|
| 155 |
+
C10_HOST_DEVICE const PtrType data() const {
|
| 156 |
+
return data_;
|
| 157 |
+
}
|
| 158 |
+
protected:
|
| 159 |
+
PtrType data_;
|
| 160 |
+
// NOLINTNEXTLINE(*c-arrays*)
|
| 161 |
+
index_t sizes_[N];
|
| 162 |
+
// NOLINTNEXTLINE(*c-arrays*)
|
| 163 |
+
index_t strides_[N];
|
| 164 |
+
C10_HOST void bounds_check_(index_t i) const {
|
| 165 |
+
TORCH_CHECK_INDEX(
|
| 166 |
+
0 <= i && i < index_t{N},
|
| 167 |
+
"Index ",
|
| 168 |
+
i,
|
| 169 |
+
" is not within bounds of a tensor of dimension ",
|
| 170 |
+
N);
|
| 171 |
+
}
|
| 172 |
+
};
|
| 173 |
+
|
| 174 |
+
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
| 175 |
+
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
|
| 176 |
+
public:
|
| 177 |
+
typedef typename PtrTraits<T>::PtrType PtrType;
|
| 178 |
+
|
| 179 |
+
C10_HOST GenericPackedTensorAccessor(
|
| 180 |
+
PtrType data_,
|
| 181 |
+
const index_t* sizes_,
|
| 182 |
+
const index_t* strides_)
|
| 183 |
+
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
| 184 |
+
|
| 185 |
+
// if index_t is not int64_t, we want to have an int64_t constructor
|
| 186 |
+
  template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
  C10_HOST GenericPackedTensorAccessor(
      PtrType data_,
      const source_index_t* sizes_,
      const source_index_t* strides_)
      : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}

  C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
    index_t* new_sizes = this->sizes_ + 1;
    index_t* new_strides = this->strides_ + 1;
    return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
  }

  C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
    const index_t* new_sizes = this->sizes_ + 1;
    const index_t* new_strides = this->strides_ + 1;
    return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
  }

  /// Returns a PackedTensorAccessor of the same dimension after transposing the
  /// two dimensions given. Does not actually move elements; transposition is
  /// made by permuting the size/stride arrays. If the dimensions are not valid,
  /// asserts.
  C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
      index_t dim1,
      index_t dim2) const {
    this->bounds_check_(dim1);
    this->bounds_check_(dim2);
    GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
        this->data_, this->sizes_, this->strides_);
    std::swap(result.strides_[dim1], result.strides_[dim2]);
    std::swap(result.sizes_[dim1], result.sizes_[dim2]);
    return result;
  }
};

template<typename T, template <typename U> class PtrTraits, typename index_t>
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
 public:
  typedef typename PtrTraits<T>::PtrType PtrType;
  C10_HOST GenericPackedTensorAccessor(
      PtrType data_,
      const index_t* sizes_,
      const index_t* strides_)
      : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}

  // if index_t is not int64_t, we want to have an int64_t constructor
  template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
  C10_HOST GenericPackedTensorAccessor(
      PtrType data_,
      const source_index_t* sizes_,
      const source_index_t* strides_)
      : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}

  C10_DEVICE T & operator[](index_t i) {
    return this->data_[this->strides_[0] * i];
  }
  C10_DEVICE const T& operator[](index_t i) const {
    return this->data_[this->strides_[0]*i];
  }

  // Same as in the general N-dimensional case, but note that in the
  // 1-dimensional case the returned PackedTensorAccessor will always be an
  // identical copy of the original
  C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
      index_t dim1,
      index_t dim2) const {
    this->bounds_check_(dim1);
    this->bounds_check_(dim2);
    return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
        this->data_, this->sizes_, this->strides_);
  }
};


// Can't put this directly into the macro function args because of commas
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>

// Old name for `GenericPackedTensorAccessor`
template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X)

#undef AT_X

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;

template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;
} // namespace at
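For context on how the accessor types above are consumed (an illustrative sketch, not part of the uploaded header): a host-side `TensorAccessor` obtained from `Tensor::accessor<T, N>()` indexes through the stored size/stride arrays, one dimension per `operator[]`, while the packed variants returned by `packed_accessor32`/`packed_accessor64` are the same idea passed by value into CUDA kernels. The helper name `trace_of` below is hypothetical.

#include <ATen/ATen.h>
#include <algorithm>

// Illustrative only: walk a 2-D float CPU tensor through a TensorAccessor,
// which indexes via sizes/strides without going through the dispatcher.
float trace_of(const at::Tensor& m) {
  auto a = m.accessor<float, 2>();  // checks at runtime that dim() == 2
  float s = 0;
  for (int64_t i = 0; i < std::min(a.size(0), a.size(1)); ++i) {
    s += a[i][i];                   // operator[] peels one dimension per call
  }
  return s;
}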
phivenv/Lib/site-packages/torch/include/ATen/core/TensorBase.h
ADDED
|
@@ -0,0 +1,1056 @@
#pragma once

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <optional>
#include <c10/util/intrusive_ptr.h>

#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/StorageUtils.h>

namespace c10 {
class Scalar;
}

namespace torch::autograd {

struct Node;

} // namespace torch::autograd

namespace at {

class Tensor;
class TensorBase;

// Convert Tensor to TensorBase without any need to include Tensor.h
TORCH_API const TensorBase& get_tensor_base(const Tensor& t);

namespace impl {
inline bool variable_excluded_from_dispatch() {
#ifdef C10_MOBILE
  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
  return true;
#else
  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
#endif
}

}

// NOTE: [Tensor vs. TensorBase]
//
// Tensor, being the central data structure in PyTorch, gets used and
// its header included almost everywhere. Unfortunately this means
// every time an operator signature is updated or changed in
// native_functions.yaml, you (and every other PyTorch developer) need
// to recompile all of ATen and its dependencies.
//
// TensorBase aims to break up these header dependencies, and improve
// incremental build times for all PyTorch developers. TensorBase
// represents a reference counted handle to TensorImpl, exactly the
// same as Tensor. However, TensorBase doesn't have code generated
// methods in its API and thus no dependence on native_functions.yaml.
//
// Usage tips
// ----------
// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
//   or .cu file to ensure it has no header dependencies on
//   native_functions.yaml (direct or indirect).
// - Tensor inherits from TensorBase, so functions taking
//   `const TensorBase &` are callable with Tensor as well.
// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
//   but this requires a reference-count bump. OptionalTensorRef, on
//   the other hand, can materialize a `const Tensor &` without
//   touching the reference-count.
class TORCH_API TensorBase {
 public:
  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };

 protected:
  // Create a Tensor with a +0 reference count. Special care must be
  // taken to avoid decrementing this reference count at destruction
  // time. Intended to support MaybeOwnedTraits<Tensor>.
  explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>(rhs.impl_.get(), c10::raw::DontIncreaseRefcount{})) {}
  friend MaybeOwnedTraits<TensorBase>;

 public:
  TensorBase() = default;
  // This constructor should not be used by end users and is an implementation
  // detail invoked by autogenerated code.
  explicit TensorBase(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
      : impl_(std::move(tensor_impl)) {
    if (impl_.get() == nullptr) {
      throw std::runtime_error("TensorImpl with nullptr is not supported");
    }
  }
  TensorBase(const TensorBase&) = default;
  TensorBase(TensorBase&&) noexcept = default;
  ~TensorBase() noexcept = default;

 public:
  // Creates a new wrapper from TensorImpl. Intentionally a free method because
  // it should be used with care. Checks necessary invariants
  static TensorBase wrap_tensor_impl(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
    TensorBase r(std::move(tensor_impl));
    r.enforce_invariants();
    return r;
  }

  int64_t dim() const {
    return impl_->dim();
  }
  int64_t storage_offset() const {
    return impl_->storage_offset();
  }

  TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
    if (is_contiguous(memory_format)) {
      return *this;
    } else {
      return __dispatch_contiguous(memory_format);
    }
  }

  /// Should be used if *this can reasonably be expected to be contiguous and
  /// performance is important.
  /// Compared to contiguous, it saves a reference count
  /// increment/decrement if *this is already contiguous, at the cost
  /// in all cases of an extra pointer of stack usage, an extra branch
  /// to access, and an extra branch at destruction time.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) const &;

  // Use .contiguous() instead. Trying to borrow from a prvalue
  // will only lead to trouble and dangling references.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;

  const TensorBase& fill_(const c10::Scalar& scalar) const;
  const TensorBase& zero_() const;

  TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=std::nullopt) const;

  bool is_complex() const {
    return at::isComplexType(this->scalar_type());
  }

  bool is_floating_point() const {
    return at::isFloatingType(this->scalar_type());
  }

  bool is_signed() const {
    return at::isSignedType(this->scalar_type());
  }

  c10::SymInt sym_size(int64_t dim) const {
    return impl_->sym_size(dim);
  }

  c10::SymInt sym_stride(int64_t dim) const {
    const auto sizes = this->sym_strides();
    const auto ndim = static_cast<int64_t>(sizes.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  int64_t size(int64_t dim) const {
    return impl_->size(dim);
  }

  int64_t stride(int64_t dim) const {
    const auto strides = this->strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  TensorImpl * unsafeGetTensorImpl() const {
    return impl_.get();
  }
  TensorImpl * unsafeReleaseTensorImpl() {
    return impl_.release();
  }
  const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
    return impl_;
  }

  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
    return std::move(impl_);
  }

  bool defined() const {
    return impl_;
  }

  void reset() {
    impl_.reset();
  }

#if defined (_MSC_VER)
  TensorBase& operator=(const TensorBase& x) & {
    impl_ = x.impl_;
    return *this;
  };
  TensorBase& operator=(TensorBase&& x) & noexcept {
    impl_ = std::move(x.impl_);
    return *this;
  }
#else
  TensorBase& operator=(const TensorBase& x) & = default;
  TensorBase& operator=(TensorBase&& x) & noexcept = default;
#endif

  // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
  TensorBase& operator=(const TensorBase&) && = delete;
  TensorBase& operator=(TensorBase&&) && noexcept = delete;

  bool is_same(const TensorBase& other) const noexcept {
    return impl_ == other.impl_;
  }
  size_t use_count() const noexcept {
    return impl_.use_count();
  }
  size_t weak_use_count() const noexcept {
    return impl_.weak_use_count();
  }

  std::string toString() const;

  IntArrayRef sizes() const {
    return impl_->sizes();
  }
  c10::SymIntArrayRef sym_sizes() const {
    return impl_->sym_sizes();
  }
  c10::SymIntArrayRef sym_strides() const {
    return impl_->sym_strides();
  }
  IntArrayRef strides() const {
    return impl_->strides();
  }
  // See impl::get_opt_names in ATen/NamedTensor.h for docs.
  std::optional<DimnameList> opt_names() const {
    return impl::get_opt_names(unsafeGetTensorImpl());
  }
  // See impl::get_names in ATen/NamedTensor.h for docs.
  DimnameList names() const {
    return impl::get_names(unsafeGetTensorImpl());
  }
  int64_t ndimension() const {
    return dim();
  }

  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
    return impl_->is_contiguous(memory_format);
  }

  bool is_non_overlapping_and_dense() const {
    return impl_->is_non_overlapping_and_dense();
  }
  at::MemoryFormat suggest_memory_format(
      bool channels_last_strides_exact_match = false) const {
    // Setting channels_last_strides_exact_match to true forces function to
    // check 0,1 - sized dimension strides.
    if (layout() == at::kStrided) {
      if (impl_->is_strides_like_channels_last()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_2d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast;
        }
      }
      else if (impl_->is_strides_like_channels_last_3d()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_3d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast3d;
        }
      }
    }
    return at::MemoryFormat::Contiguous;
  }

  // Total bytes consumed by the "view" of elements of the array. Does not
  // include size of metadata. The number reported here does not necessarily
  // correspond to the true physical memory consumed by a tensor; instead,
  // it reports the memory the tensor would take *if* it were contiguous.
  // Defined to be numel() * itemsize()
  size_t nbytes() const {
    TORCH_CHECK(layout () != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values. If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->numel() * impl_->itemsize();
  }

  c10::SymInt sym_nbytes() const {
    TORCH_CHECK(layout () != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values. If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->sym_numel() * impl_->itemsize();
  }

  int64_t numel() const {
    return impl_->numel();
  }

  c10::SymInt sym_numel() const {
    return impl_->sym_numel();
  }

  c10::SymInt sym_storage_offset() const {
    return impl_->sym_storage_offset();
  }

  // Length of one array element in bytes. This is the traditional
  // Numpy naming.
  size_t itemsize() const {
    return impl_->itemsize();
  }

  // Same as itemsize(). This is the PyTorch naming.
  int64_t element_size() const {
    return static_cast<int64_t>(impl_->itemsize());
  }

  DispatchKeySet key_set() const {
    return impl_->key_set();
  }
  ScalarType scalar_type() const {
    return typeMetaToScalarType(impl_->dtype());
  }
  bool has_storage() const {
    return defined() && impl_->has_storage();
  }
  const Storage& storage() const {
    return impl_->storage();
  }
  bool is_alias_of(const at::TensorBase& other) const{
    return impl_->storage().is_alias_of(other.storage());
  }

  // Move the storage backend to shm based
  // to enable memory sharing across processes.
  //
  // NB1: the ideal behavior of this API still requires further discussion
  // but for now we are inclined to keep it consistent with existing THP behavior
  // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212
  // so we don't assert on anything here and rely on caller knowing
  // what it's doing.
  //
  // NB2: this currently provides Linux fd based shm support only
  // to simplify the storage lifetime management logic in ATen
  // and similarly for now we are not adding support for file system based
  // shm support like in THP due to additional GC manager support needed
  // to prevent leaks.
  // As such, calling this from non supported systems (e.g. Windows) would fail.
  void share_memory_() {
    at::share_memory_(*this);
  }

  inline bool _is_zerotensor() const {
    return impl_->_is_zerotensor();
  }

  inline void _set_zero(bool zero) const {
    impl_->_set_zero(zero);
  }

  inline bool is_conj() const {
    return impl_->is_conj();
  }

  // sets the conjugate bit of a tensor.
  // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
  inline void _set_conj(bool conjugate) const {
    impl_->_set_conj(conjugate);
  }

  inline bool is_neg() const {
    return impl_->is_neg();
  }

  // sets the negative bit of a tensor.
  // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
  // bit to determine if a negation needs to be materialized.
  inline void _set_neg(bool negative) const {
    impl_->_set_neg(negative);
  }
  /// Returns a `Tensor`'s layout.
  Layout layout() const {
    return impl_->layout();
  }

  /// Returns a `Tensor`'s dtype (`TypeMeta`).
  caffe2::TypeMeta dtype() const {
    return impl_->dtype();
  }

  /// Returns a `Tensor`'s device.
  inline Device device() const {
    return impl_->device();
  }

  /// Returns a `Tensor`'s device index.
  DeviceIndex get_device() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->get_device();
  }

  /// Returns if a `Tensor` has CPU backend.
  bool is_cpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cpu();
  }

  /// Returns if a `Tensor` has CUDA backend.
  bool is_cuda() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cuda();
  }

  /// Returns if a `Tensor` has IPU backend.
  bool is_ipu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ipu();
  }

  /// Returns if a `Tensor` has XPU backend.
  bool is_xpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_xpu();
  }

  /// Returns if a `Tensor` has XLA backend.
  bool is_xla() const {
    return impl_->is_xla();
  }

  /// Returns if a `Tensor` has MTIA backend.
  bool is_mtia() const {
    return impl_->is_mtia();
  }

  /// Returns if a `Tensor` has HPU backend.
  bool is_hpu() const {
    return impl_->is_hpu();
  }

  /// Returns if a `Tensor` has Lazy backend.
  bool is_lazy() const {
    return impl_->is_lazy();
  }

  /// Returns if a `Tensor` has HIP backend.
  bool is_hip() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_hip();
  }

  /// Returns if a `Tensor` has VE backend.
  bool is_ve() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ve();
  }

  /// Returns if a `Tensor` has PrivateUse1 backend.
  bool is_privateuseone() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_privateuseone();
  }

  /// Returns if a `Tensor` has sparse backend.
  bool is_sparse() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse();
  }

  /// Returns is a `Tensor` has a sparse CSR backend.
  bool is_sparse_csr() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse_csr();
  }

  /// Returns if a `Tensor` is mkldnn tensor.
  bool is_mkldnn() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mkldnn();
  }

  /// Returns if a `Tensor` is mps tensor.
  bool is_mps() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mps();
  }

  /// Returns if a `Tensor` is maia tensor.
  bool is_maia() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_maia();
  }

  /// Returns if a `Tensor` is vulkan tensor.
  bool is_vulkan() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_vulkan();
  }

  /// Returns if a `Tensor` is metal tensor.
  bool is_metal() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_metal();
  }

  /// Returns if a `Tensor` has quantized backend.
  bool is_quantized() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_quantized();
  }

  /// Returns if a `Tensor` is a meta tensor. Meta tensors can
  /// also have other designations.
  bool is_meta() const {
    return impl_->is_meta();
  }

  /// Returns if a `Tensor` is an inference tensor.
  bool is_inference() const {
    return impl_->is_inference();
  }

  // Returns if a `Tensor` is a NestedTensor.
  bool is_nested() const {
    return impl_->is_nested();
  }

  /// If a tensor is a quantized tensor, returns its quantizer
  /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
  QuantizerPtr quantizer() const;

  /// Returns if a `Tensor` has any dimension names
  bool has_names() const {
    // If a user is using unnamed tensors, then we can short-circuit right here.
    // Otherwise, impl::has_names attempts to retrieve names.
    if (!impl_->has_named_tensor_meta()) {
      return false;
    }
    return impl::has_names(unsafeGetTensorImpl());
  }

  /// Returns a `Tensor`'s dimension names data structure
  const NamedTensorMeta* get_named_tensor_meta() const {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  NamedTensorMeta* get_named_tensor_meta() {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
  /// TensorOptions.h.
  TensorOptions options() const {
    return TensorOptions().dtype(dtype())
                          .device(device())
                          .layout(layout());
  }
  const void* const_data_ptr() const {
    return this->unsafeGetTensorImpl()->data();
  }

  void* mutable_data_ptr() const {
    return this->unsafeGetTensorImpl()->mutable_data();
  }

  // TODO(#97856) Make this return a const pointer. This currently
  //              returns a non-const pointer because of the large
  //              number of clients that we still want to audit before
  //              migrating to mutable_data_ptr().
  void* data_ptr() const {
    return mutable_data_ptr();
  }

  template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
  const T* const_data_ptr() const;

  template <typename T, std::enable_if_t<std::is_const_v<T>, int> = 0>
  const std::remove_const_t<T>* const_data_ptr() const;

  template <typename T>
  T* mutable_data_ptr() const;

  // Legacy interface during the migration to indicate that a callsite
  // has not been audited for mutability.
  //
  // Do not add new uses of this, use const_data_ptr() if possible,
  // mutable_data_ptr() otherwise.
  //
  // TODO(#97856) Make this return a const pointer. This is currently
  //              const because of the vast number of clients that
  //              rely on this.
  template <typename T>
  T* data_ptr() const;

  // Purposely not defined here to avoid inlining
  void print() const;

  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
  // dimension.
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    T* ptr = nullptr;
    if constexpr (std::is_const_v<T>) {
      ptr = const_data_ptr<T>();
    } else {
      ptr = mutable_data_ptr<T>();
    }
    return TensorAccessor<T,N>(ptr,sizes().data(),strides().data());
  }
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() && = delete;

  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
  // cast the data pointer to a __restrict__ pointer.
  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
  // as an argument.
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    T* ptr = nullptr;
    if constexpr (std::is_const_v<T>) {
      ptr = const_data_ptr<T>();
    } else {
      ptr = mutable_data_ptr<T>();
    }
    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(ptr),sizes().data(),strides().data());
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
    TORCH_CHECK(
        impl_->numel() <=
            static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
        "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
    return generic_packed_accessor<T,N,PtrTraits,int32_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
    return generic_packed_accessor<T,N,PtrTraits,int64_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;

  // ~~~~~ Autograd API ~~~~~

  /// \fn bool is_leaf() const;
  ///
  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
  ///
  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
  /// created by the user. This means that they are not the result of an operation and so
  /// `grad_fn()` is `nullptr`.
  ///
  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
  ///
  /// Example:
  /// @code
  /// auto a = torch::rand(10, torch::requires_grad());
  /// std::cout << a.is_leaf() << std::endl; // prints `true`
  ///
  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
  /// std::cout << b.is_leaf() << std::endl; // prints `false`
  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
  ///
  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
  /// std::cout << c.is_leaf() << std::endl; // prints `false`
  /// // c was created by the addition operation
  ///
  /// auto d = torch::rand(10).cuda();
  /// std::cout << d.is_leaf() << std::endl; // prints `true`
  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
  ///
  /// auto e = torch::rand(10).cuda().requires_grad_();
  /// std::cout << e.is_leaf() << std::endl; // prints `true`
  /// // e requires gradients and has no operations creating it
  ///
  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
  /// std::cout << f.is_leaf() << std::endl; // prints `true`
  /// // f requires grad, has no operation creating it
  /// @endcode

  /// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
  ///
  /// Computes the gradient of current tensor with respect to graph leaves.
  ///
  /// The graph is differentiated using the chain rule. If the tensor is
  /// non-scalar (i.e. its data has more than one element) and requires
  /// gradient, the function additionally requires specifying ``gradient``.
  /// It should be a tensor of matching type and location, that contains
  /// the gradient of the differentiated function w.r.t. this Tensor.
  ///
  /// This function accumulates gradients in the leaves - you might need to
  /// zero them before calling it.
  ///
  /// \param gradient Gradient w.r.t. the
  ///     tensor. If it is a tensor, it will be automatically converted
  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
  ///     None values can be specified for scalar Tensors or ones that
  ///     don't require grad. If a None value would be acceptable then
  ///     this argument is optional.
  /// \param retain_graph If ``false``, the graph used to compute
  ///     the grads will be freed. Note that in nearly all cases setting
  ///     this option to True is not needed and often can be worked around
  ///     in a much more efficient way. Defaults to the value of
  ///     ``create_graph``.
  /// \param create_graph If ``true``, graph of the derivative will
  ///     be constructed, allowing to compute higher order derivative
  ///     products. Defaults to ``false``.
  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
  ///     provided, the gradient is accumulated into all the leaf Tensors
  ///     that were used to compute the current tensor.
  ///     When inputs are provided and a given input is not a leaf,
  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients).
  ///     It is an implementation detail on which the user should not rely.
  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

  /// \fn Tensor detach() const;
  ///
  /// Returns a new Tensor, detached from the current graph.
  /// The result will never require gradient.

  /// \fn Tensor & detach_() const;
  ///
  /// Detaches the Tensor from the graph that created it, making it a leaf.
  /// Views cannot be detached in-place.

  /// \fn void retain_grad() const;
  ///
  /// Enables this Tensor to have their :attr:`grad` populated during
  /// :func:`backward`. This is a no-op for leaf tensors.

  /// \fn bool retains_grad() const;
  ///
  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
  /// populated during :func:`backward`, ``false`` otherwise.

  const TensorBase& set_requires_grad(bool requires_grad) const {
    impl_->set_requires_grad(requires_grad);
    return *this;
  }
  bool requires_grad() const {
    return impl_->requires_grad();
  }

  // The Forward AD API functions below are low level and are not to be used by end
  // users who should use the API provided in torch/csrc/autograd.h

  /// This function returns the forward gradient for this Tensor at the given level.
  const Tensor& _fw_grad(uint64_t level) const {
    return impl_->_fw_grad(level, *this);
  }

  /// This function can be used to set the value of the forward grad.
  /// Note that the given new_grad might not be used directly if it has different
  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
  /// new_grad content will be copied into a new Tensor
  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
  }

  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
  ///
  /// One notable difference with the legacy `.data()` function is that changes to the
  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
  /// will not update the original `Variable`, due to the fact that this function
  /// shallow-copies the `Variable`'s underlying TensorImpl.
  at::TensorBase tensor_data() const;

  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
  /// in Python, which create a new `Variable` that shares the same storage and
  /// tensor metadata with the original `Variable`, but with a completely new
  /// autograd history.
  ///
  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
  /// changes will not update the original variable `var`. In `.variable_data()`, we set
  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
  /// in order to prevent users from changing metadata of `var.variable_data()`
  /// and expecting the original variable `var` to also be updated.
  at::TensorBase variable_data() const;

  // Gradient Node and Edges
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
  /// the pointer returned will be null.
  ///
  /// For View Variables:
  /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
  /// re-create the grad_fn to express the up-to-date view relationship between
  /// this and the base Variable.
  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;

  // Hooks
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  template <typename T>
  using hook_return_void_t = std::enable_if_t<std::is_void_v<typename std::invoke_result_t<T&, TensorBase>>, unsigned>;
  template <typename T>
  using hook_return_var_t = std::enable_if_t<std::is_same_v<typename std::invoke_result_t<T&, TensorBase>, TensorBase>, unsigned>;

  /// Registers a backward hook.
  ///
  /// The hook will be called every time a gradient with respect to the Tensor is computed.
  /// The hook should have one of the following signature:
  /// ```
  /// hook(TensorBase grad) -> TensorBase
  /// ```
  /// ```
  /// hook(TensorBase grad) -> void
  /// ```
  /// The hook should not modify its argument, but it can optionally return a new gradient
  /// which will be used in place of `grad`.
  ///
  /// This function returns the index of the hook in the list which can be used to remove hook.
  ///
  /// Example:
  /// @code
  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
  /// v.backward(torch::tensor({1., 2., 3.}));
  /// // This prints:
  /// // ```
  /// //  2
  /// //  4
  /// //  6
  /// // [ CPUFloatType{3} ]
  /// // ```
  /// std::cout << v.grad() << std::endl;
  /// v.remove_hook(h);  // removes the hook
  /// @endcode
  template <typename T>
  hook_return_void_t<T> register_hook(T&& hook) const;
  template <typename T>
  hook_return_var_t<T> register_hook(T&& hook) const;

 protected:
  unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;

 public:

  /// Remove hook at given position
  void remove_hook(unsigned pos) const;

  // Variable methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  bool is_leaf() const;

  int64_t output_nr() const;

  void set_data(const TensorBase & new_data) const;

  TensorBase data() const;

  int64_t _version() const;

  void retain_grad() const;

  bool retains_grad() const;

  const TensorBase& requires_grad_(bool _requires_grad=true) const;

  // View Variables
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Returns true if this `Variable` is a view of another `Variable`.
  bool is_view() const;

  /// Returns the `Variable` that this `Variable` is a view of. If this
  /// `Variable` is not a view, throw a `std::runtime_error`.
  const TensorBase& _base() const;

  // Miscellaneous
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  const std::string& name() const;

 protected:
  void enforce_invariants();
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

 private:
  TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
};

inline DeviceIndex get_device(const TensorBase& self) {
  return self.get_device();
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
  // Return the grad argument in case of a hook with void return type to have an
  // std::function with Tensor return type
  static_assert(std::is_same_v<decltype(hook(TensorBase())), void>,
                "Expected hook to return void");
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
    fn(grad);
    return TensorBase();
  });
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
  return _register_hook(std::forward<T>(hook));
}

namespace detail {
// Helper creator for Tensor class which doesn't requires the users to pass
// in an intrusive_ptr instead it just converts the argument passed to
// requested intrusive_ptr type.
template <typename T, typename... Args>
TensorBase make_tensor_base(Args&&... args) {
  return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
}

} // namespace detail

inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
  return legacyExtractDispatchKey(t.key_set());
}

} // namespace at

namespace c10 {
template <>
struct MaybeOwnedTraits<at::TensorBase> {
  using owned_type = at::TensorBase;
  using borrow_type = at::TensorBase;

  static borrow_type createBorrow(const owned_type& from) {
    // NOTE: this can be implemented without the special
    // unsafe_borrow_t Tensor constructor as
    //
    // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
    //
    // but that hurts inlining due to the nullptr check in the
    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
    // that from.impl_ isn't null because from is a valid Tensor, so
    // we needn't do the check again. (using __builtin_assume can
    // avoid this, but wouldn't be portable to MSVC.)
    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.unsafeReleaseTensorImpl();
    // See above note: this can be implemented with public API
    // similarly to createBorrow(), but that would hurt inlining.
    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};

template <>
struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
} // namespace c10

namespace at {

inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
    const std::optional<TensorBase>& opt) {
  return opt.has_value()
    ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
    : c10::MaybeOwned<TensorBase>::owned(std::in_place);
}

inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
  if (is_contiguous(memory_format)) {
    return c10::MaybeOwned<TensorBase>::borrowed(*this);
  } else {
    return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
  }
}

namespace symint {

template <typename T>
using enable_if_symint = std::enable_if_t<std::is_same_v<T, c10::SymInt>>;
template <typename T>
using enable_if_int = std::enable_if_t<std::is_same_v<T, int64_t>>;

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
template <typename T, typename = enable_if_int<T>>
int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef strides(const TensorBase& t) { return t.strides(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
template <typename T, typename = enable_if_int<T>>
int64_t numel(const TensorBase& t) { return t.numel(); }

} // namespace symint

} // namespace at
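A brief usage sketch for the class above (illustrative only; the helper name `contiguous_numel` is hypothetical, not part of the upload): functions that only need metadata or raw storage can take `const TensorBase&`, which also accepts `at::Tensor` since Tensor derives from TensorBase, and `expect_contiguous()` avoids a reference-count bump when the input is already contiguous.

#include <ATen/core/TensorBase.h>

// Hypothetical helper: borrows `t` when it is already contiguous,
// otherwise owns the temporary produced by the contiguous dispatch.
int64_t contiguous_numel(const at::TensorBase& t) {
  c10::MaybeOwned<at::TensorBase> c = t.expect_contiguous();
  return c->numel();
}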
phivenv/Lib/site-packages/torch/include/ATen/core/TensorBody.h
ADDED
|
The diff for this file is too large to render.
See raw diff
phivenv/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h
ADDED
|
@@ -0,0 +1,17 @@
#pragma once

#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/ArrayRef.h>
#include <torch/library.h>
#include <optional>

namespace at::impl {

TORCH_API bool tensor_has_dispatch(const at::Tensor& t);
TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li);
TORCH_API bool tensorlist_has_dispatch(
    const c10::List<std::optional<at::Tensor>>& li);
using c10::impl::dispatch_mode_enabled;

} // namespace at::impl
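Illustrative use of the declarations above (the guard name below is hypothetical and not part of the uploaded header): callers that want to take a fast path typically check both the per-tensor Python dispatch key and the thread-local dispatch mode.

#include <ATen/core/TorchDispatchUtils.h>

// Hypothetical guard: take the slow path if __torch_dispatch__ could observe this tensor.
bool needs_python_dispatch(const at::Tensor& t) {
  return at::impl::tensor_has_dispatch(t) || c10::impl::dispatch_mode_enabled();
}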
phivenv/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h
ADDED
|
@@ -0,0 +1,175 @@
| 1 |
+
#include <ATen/NumericUtils.h>
|
| 2 |
+
#include <c10/macros/Macros.h>
|
| 3 |
+
#include <c10/util/Half.h>
|
| 4 |
+
#include <c10/util/BFloat16.h>
|
| 5 |
+
#include <c10/util/MathConstants.h>
|
| 6 |
+
#include <cmath>
|
| 7 |
+
#include <cstdint>
|
| 8 |
+
#include <cassert>
|
| 9 |
+
#include <limits>
|
| 10 |
+
#include <type_traits>
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
// Using DistAccumType in accumulate types for distributions.
|
| 15 |
+
// Note: Ideally we'd be using ATen/AccumulateType.h but looks
|
| 16 |
+
// like the there is some inconsistency in how accumulate types
|
| 17 |
+
// are mapped currently, e.g. for the cpu side, float is mapped
|
| 18 |
+
// to double.
|
| 19 |
+
template <typename T>
|
| 20 |
+
struct DistAccumType { };
|
| 21 |
+
|
| 22 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 23 |
+
template <> struct DistAccumType<half> { using type = float; };
|
| 24 |
+
#endif
|
| 25 |
+
template <> struct DistAccumType<BFloat16> { using type = float; };
|
| 26 |
+
template <> struct DistAccumType<Half> { using type = float; };
|
| 27 |
+
template <> struct DistAccumType<float> { using type = float; };
|
| 28 |
+
template <> struct DistAccumType<double> { using type = double; };
|
| 29 |
+
|
| 30 |
+
template <typename T>
|
| 31 |
+
using dist_acctype = typename DistAccumType<T>::type;
|
| 32 |
+
|
| 33 |
+
namespace transformation {
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified.
|
| 37 |
+
* `range` is `to - from`
|
| 38 |
+
* `base` is `from`
|
| 39 |
+
*/
|
| 40 |
+
template <typename T, typename V>
|
| 41 |
+
C10_HOST_DEVICE inline T uniform_int_from_to(V val, uint64_t range, int64_t base) {
|
| 42 |
+
return static_cast<T>(static_cast<int64_t>((val % range) + base));
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/**
|
| 46 |
+
* A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and to=None
|
| 47 |
+
*/
|
| 48 |
+
template <typename T, typename V>
|
| 49 |
+
C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
|
| 50 |
+
return static_cast<T>(static_cast<int64_t>(val));
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
/**
|
| 54 |
+
* A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
|
| 55 |
+
* In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double
|
| 56 |
+
* in this overloaded version
|
| 57 |
+
*/
|
| 58 |
+
template <typename T, typename V>
|
| 59 |
+
C10_HOST_DEVICE inline std::enable_if_t<!(std::is_floating_point_v<T>), T>uniform_int(V val) {
|
| 60 |
+
if constexpr (std::is_same_v<T, bool>) {
|
| 61 |
+
return static_cast<bool>(val & 1);
|
| 62 |
+
} else if constexpr (std::is_same_v<T, int64_t>) {
|
| 63 |
+
return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
|
| 64 |
+
} else if constexpr (std::is_same_v<T, at::Half> || std::is_same_v<T, at::BFloat16>) {
|
| 65 |
+
return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
|
| 66 |
+
} else if constexpr (std::is_integral_v<T>) {
|
| 67 |
+
return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
|
| 68 |
+
} else {
|
| 69 |
+
assert(false);
|
| 70 |
+
return 0;
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
/**
|
| 75 |
+
* An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`,
|
| 76 |
+
* added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version.
|
| 77 |
+
*/
|
| 78 |
+
template<typename T, typename V>
|
| 79 |
+
C10_HOST_DEVICE inline std::enable_if_t<std::is_floating_point_v<T>, T>uniform_int(V val) {
|
| 80 |
+
return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template <typename T, typename V>
|
| 84 |
+
C10_HOST_DEVICE inline dist_acctype<T> uniform_real(V val, T from, T to) {
|
| 85 |
+
constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
|
| 86 |
+
constexpr auto DIVISOR = static_cast<dist_acctype<T>>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
|
| 87 |
+
dist_acctype<T> x = (val & MASK) * DIVISOR;
|
| 88 |
+
return (x * (to - from) + from);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
/**
|
| 92 |
+
* Transforms normally distributed `val` with mean 0.0 and standard deviation 1.0 to
|
| 93 |
+
* normally distributed with `mean` and standard deviation `std`.
|
| 94 |
+
*/
|
| 95 |
+
template <typename T>
|
| 96 |
+
C10_HOST_DEVICE inline T normal(T val, T mean, T std) {
|
| 97 |
+
return val * std + mean;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
/**
|
| 101 |
+
* Transforms uniformly distributed `val` between 0.0 and 1.0 to
|
| 102 |
+
* Cauchy distribution with location parameter `median` and scale parameter `sigma`.
|
| 103 |
+
*/
|
| 104 |
+
template <typename T>
|
| 105 |
+
C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) {
|
| 106 |
+
// https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
|
| 107 |
+
// __tanf overflows and returns `inf/-inf` when (val > 1 - eps) or (val < 0 + eps),
|
| 108 |
+
// thus we clip those values.
|
| 109 |
+
constexpr T eps = std::numeric_limits<T>::epsilon();
|
| 110 |
+
constexpr T one_minus_eps = 1 - eps;
|
| 111 |
+
constexpr T zero_plus_eps = 0 + eps;
|
| 112 |
+
val = (val > one_minus_eps ? one_minus_eps : val);
|
| 113 |
+
val = (val < zero_plus_eps ? zero_plus_eps : val);
|
| 114 |
+
return median + sigma * at::tan(c10::pi<T> * (val - static_cast<T>(0.5)));
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
template <>
|
| 118 |
+
C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) {
|
| 119 |
+
// https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
|
| 120 |
+
return median + sigma * at::tan(c10::pi<double> * (val - static_cast<double>(0.5)));
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
/**
|
| 124 |
+
* Transforms uniformly distributed `val` between 0.0 and 1.0 to
|
| 125 |
+
* exponentially distributed with `lambda` parameter of the distribution.
|
| 126 |
+
*/
|
| 127 |
+
template <typename T>
|
| 128 |
+
C10_HOST_DEVICE inline T exponential(T val, T lambda) {
|
| 129 |
+
// https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
|
| 130 |
+
// Different implementations for CUDA and CPU to preserve original logic
|
| 131 |
+
// TODO: must be investigated and unified!!!
|
| 132 |
+
// https://github.com/pytorch/pytorch/issues/38662
|
| 133 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 134 |
+
// BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
|
| 135 |
+
// curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
|
| 136 |
+
// we need log to be not 0, and not underflow when converted to half
|
| 137 |
+
// fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
|
| 138 |
+
auto log = val >= static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2
|
| 139 |
+
? -std::numeric_limits<T>::epsilon() / 2
|
| 140 |
+
: at::log(val);
|
| 141 |
+
return static_cast<T>(-1.0) / lambda * log;
|
| 142 |
+
#else
|
| 143 |
+
return static_cast<T>(-1.0) / lambda * at::log1p(-val);
|
| 144 |
+
#endif
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
/**
|
| 148 |
+
* Transforms uniformly distributed `val` between 0.0 and 1.0 to
|
| 149 |
+
* geometrically distributed with success probability `p`.
|
| 150 |
+
*/
|
| 151 |
+
template <typename T>
|
| 152 |
+
C10_HOST_DEVICE inline T geometric(T val, T p) {
|
| 153 |
+
// https://en.wikipedia.org/wiki/Geometric_distribution#Related_distributions
|
| 154 |
+
return static_cast<T>(::ceil(at::log(val) / at::log1p(-p)));
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/**
|
| 158 |
+
* Transforms normally distributed `val` to log-normally distributed.
|
| 159 |
+
*/
|
| 160 |
+
template <typename T>
|
| 161 |
+
C10_HOST_DEVICE inline T log_normal(T val) {
|
| 162 |
+
// https://en.wikipedia.org/wiki/Log-normal_distribution#Mode,_median,_quantiles
|
| 163 |
+
return at::exp(val);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/**
|
| 167 |
+
* Transforms uniformly distributed `val` between 0.0 and 1.0 to
|
| 168 |
+
* bernoulli distributed with success probability `p`.
|
| 169 |
+
*/
|
| 170 |
+
template <typename T>
|
| 171 |
+
C10_HOST_DEVICE inline T bernoulli(T val, T p) {
|
| 172 |
+
return val < p;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
}} // namespace at::transformation
|
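As a rough illustration of the transformation helpers in this header, the sketch below maps raw 64-bit random draws to a uniform real and a Bernoulli sample. It is my own example: a std::mt19937_64 stands in for ATen's RNG engines, and real ATen distribution kernels are structured differently.

// Illustrative only: feed 64 random bits through the transformation
// helpers above to get a uniform double in [from, to) and a Bernoulli sample.
#include <random>
#include <cstdint>
#include <ATen/core/TransformationHelper.h>

double sample_uniform(std::mt19937_64& rng, double from, double to) {
  uint64_t bits = rng();
  return at::transformation::uniform_real<double>(bits, from, to);
}

bool sample_bernoulli(std::mt19937_64& rng, double p) {
  double u = sample_uniform(rng, 0.0, 1.0);
  return at::transformation::bernoulli<double>(u, p) != 0.0;
}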
phivenv/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h
ADDED
@@ -0,0 +1 @@
#include <c10/core/UndefinedTensorImpl.h>
phivenv/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h
ADDED
@@ -0,0 +1,21 @@
#pragma once
#include <ATen/core/Tensor.h>

namespace at {

inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) {
  auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
  if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
    c10::raw::intrusive_ptr::incref(tensor_impl.get());
  }
  return Tensor(std::move(tensor_impl));
}

inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) {
  if (retain && th_pointer) {
    c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
  }
  return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
}

}
phivenv/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h
ADDED
@@ -0,0 +1,83 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <c10/macros/Export.h>

// A little explanation about why this file exists at all. We have
// a few methods on Tensor class which require access to reified access to
// AutogradMeta. In open source, this isn't a big deal: we just access
// torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and
// we can put the definitions inline. This is because everything gets balled
// into a single dynamic library in the end.
//
// However, inside our Facebook internal version of our build system, we
// have a split between aten and torch/csrc. So we cannot simply just
// cross this boundary. "Now wait," you might say, "Why don't we just
// merge the libraries inside Facebook". Well, the problem is that there
// are some downstream applications which are at binary size limit, and
// incorporating all of the extra code from libtorch would push them
// over (admarket/adreview/service:adreviewservice, see also
// https://github.com/pytorch/pytorch/pull/29299) So if you want to do that,
// we have to fix all of the services like this.
//
// I didn't want to block eliminating Tensor-Variable on this work, so I
// had to introduce another dynamic dispatch to get to the variable
// implementations (which live in torch/csrc/autograd/variable.cpp, FYI).
//
// I also considered using our existing dynamic dispatch mechanism, c10
// dispatcher, to do this. However, (1) some of the functions on Tensor
// have weird signatures that are not supported by autograd, and (2)
// see this bug https://github.com/pytorch/pytorch/issues/30102

namespace torch::autograd {

struct Node;

} // namespace torch::autograd

namespace at::impl {

struct TORCH_API VariableHooksInterface {
  virtual ~VariableHooksInterface() = default;
  virtual TensorBase tensor_data(const TensorBase&) const = 0;
  virtual TensorBase variable_data(const TensorBase&) const = 0;
  virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(
      const TensorBase&) const = 0;
  virtual unsigned _register_hook(
      const TensorBase&,
      std::function<TensorBase(const TensorBase&)> hook) const = 0;
  virtual void remove_hook(const TensorBase&, unsigned pos) const = 0;
  virtual bool is_view(const TensorBase&) const = 0;
  virtual const TensorBase& base(const TensorBase&) const = 0;
  virtual const std::string& name(const TensorBase&) const = 0;
  virtual bool is_leaf(const TensorBase&) const = 0;
  virtual int64_t output_nr(const TensorBase&) const = 0;
  virtual void set_data(const TensorBase&, const TensorBase&) const = 0;
  virtual TensorBase data(const TensorBase&) const = 0;
  virtual int64_t _version(const TensorBase&) const = 0;
  virtual void retain_grad(const TensorBase&) const = 0;
  virtual bool retains_grad(const TensorBase&) const = 0;
  virtual void _backward(
      const Tensor&,
      TensorList,
      const std::optional<Tensor>&,
      std::optional<bool>,
      bool) const = 0;
  virtual void requires_grad_(const TensorBase&, bool) const = 0;
  virtual void basic_autograd_not_implemented_fallback(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet dispatch_keys,
      torch::jit::Stack* stack) const = 0;
};

TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
TORCH_API VariableHooksInterface* GetVariableHooks();
TORCH_API bool HasVariableHooks();

struct TORCH_API VariableHooksRegisterer {
  explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
    SetVariableHooks(hooks);
  }
};

} // namespace at::impl
phivenv/Lib/site-packages/torch/include/ATen/core/Variadic.h
ADDED
@@ -0,0 +1,92 @@
#pragma once

#include <utility>

#include <c10/util/ArrayRef.h>
#include <ATen/core/List.h>

namespace at {

// This class allows you to write variadic functions which
// call a (possibly overloaded) function on each argument,
// in order. This is most commonly used in autogenerated code,
// where it is convenient to have a function that can uniformly
// take arguments of different types. If your arguments
// are homogenous consider using a std::initializer_list instead.
//
// For examples of this in use, see torch/csrc/utils/variadic.h
template <typename F>
struct IterArgs {
  template <typename... Args>
  inline F& apply() {
    return self();
  }

  // NB: Use perfect forwarding here, otherwise we'll make value
  // copies of all arguments!
  template <typename T, typename... Args>
  inline F& apply(T&& arg, Args&&... args) {
    self()(std::forward<T>(arg));
    if (self().short_circuit()) {
      return self();
    } else {
      return apply(std::forward<Args>(args)...);
    }
  }

  // Here are some handy overloads which provide sensible
  // defaults for container-like structures that one might
  // be interested in recursing into. You can enable them
  // by adding:
  //
  //    using IterArgs<YourStructName>::operator()
  //
  // to your struct. These are not enabled by default because
  // you may be able to process these structures more efficiently
  // than handling them one-by-one.

  template <typename T>
  void operator()(c10::IListRef<T> args) {
    for (const auto& arg : args) {
      self()(arg);
      if (self().short_circuit())
        return;
    }
  }

  template <typename T>
  void operator()(at::ArrayRef<T> args) {
    for (const auto& arg : args) {
      self()(arg);
      if (self().short_circuit())
        return;
    }
  }

  template <typename T>
  void operator()(const torch::List<T>& args) {
    for (const auto& arg : args) {
      self()(arg);
      if (self().short_circuit())
        return;
    }
  }

  // NB: we need to specify std::vector manually as C++ won't
  // do an implicit conversion to make a template deduction go through.
  template <typename T>
  void operator()(const std::vector<T>& args) {
    self()(at::ArrayRef<T>{args});
  }

  constexpr bool short_circuit() const {
    return false;
  }

 private:
  inline F& self() {
    return *static_cast<F*>(this);
  }
};

} // namespace torch
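In the spirit of the torch/csrc/utils/variadic.h examples the comment above refers to, here is a small illustrative IterArgs subclass of my own; the CountTensors name and count_tensors helper are not part of the diff.

// Hypothetical CRTP subclass: counts how many Tensor arguments appear
// in a mixed variadic argument pack.
#include <ATen/core/Tensor.h>
#include <ATen/core/Variadic.h>

struct CountTensors : at::IterArgs<CountTensors> {
  size_t count = 0;
  void operator()(const at::Tensor&) { ++count; }
  template <typename T>
  void operator()(const T&) {}  // ignore non-Tensor arguments
};

template <typename... Args>
size_t count_tensors(Args&&... args) {
  return CountTensors().apply(std::forward<Args>(args)...).count;
}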
phivenv/Lib/site-packages/torch/include/ATen/core/Vitals.h
ADDED
@@ -0,0 +1,94 @@
#pragma once
#include <ostream>
#include <sstream>
#include <unordered_map>

#include <c10/core/impl/LocalDispatchKeySet.h>

namespace at::vitals {

TORCH_API bool torchVitalEnabled();

struct TORCH_API TorchVitalAttr {
  // always initialized to empty
  std::string value;
  template <typename T>
  TorchVitalAttr& operator<<(const T& t) {
    if (torchVitalEnabled()) {
      std::stringstream ss;
      ss << t;
      value += ss.str();
    }
    return *this;
  }

  template <typename T>
  void write(const T& t, bool force) {
    if (force || torchVitalEnabled()) {
      std::stringstream ss;
      ss << t;
      value = ss.str();
    }
  }
};

struct TORCH_API TorchVital {
  std::string name;
  std::unordered_map<std::string, TorchVitalAttr> attrs;

  explicit TorchVital(std::string n) : name(std::move(n)) {}
  TorchVital(const TorchVital&) = default;
  TorchVital(TorchVital&&) = default;
  TorchVital& operator=(const TorchVital&) = default;
  TorchVital& operator=(TorchVital&&) = default;
  TorchVital() = delete;

  TorchVitalAttr& create(const std::string& attr);
  TorchVitalAttr& create(const std::string& attr, bool force);
  friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt);

  ~TorchVital();
};

std::ostream& operator<<(std::ostream& os, TorchVital const& tv);

// A way to access vitals by string names instead of by global reference.
// This enables access to vitals from the PythonAPI.
class TORCH_API APIVitals {
 public:
  bool vitals_enabled;

  // Set any vital sign that was added to the map.
  bool setVital(
      const std::string& vital_name,
      const std::string& attr_name,
      const std::string& value,
      bool force = false);
  std::string readVitals();

  APIVitals();

  // Ensure this stays a singleton
  APIVitals(APIVitals const& other) = delete;
  APIVitals(APIVitals&& other) = delete;
  APIVitals& operator=(const APIVitals&) = delete;
  APIVitals& operator=(APIVitals&&) = delete;
  ~APIVitals() = default;

 private:
  std::unordered_map<std::string, TorchVital> name_map_;
};

extern TORCH_API APIVitals VitalsAPI;

} // namespace at::vitals

#define TORCH_VITAL_DECLARE(name) \
  TORCH_API at::vitals::TorchVital TorchVital_##name;

#define TORCH_VITAL_DEFINE(name) \
  TORCH_API at::vitals::TorchVital TorchVital_##name(#name);

#define TORCH_VITAL_BASE(name) TorchVital_##name

#define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr)
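A minimal sketch of how the vitals macros above are meant to be used. The Dataloader vital and num_workers attribute are my own example names, and nothing is recorded unless vitals are enabled at runtime.

// Illustrative use of the TORCH_VITAL_* macros declared above.
#include <ATen/core/Vitals.h>

TORCH_VITAL_DEFINE(Dataloader);

void record_worker_count(int workers) {
  // Streams into the attribute only when torchVitalEnabled() is true.
  TORCH_VITAL(Dataloader, num_workers) << workers;
}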
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h
ADDED
@@ -0,0 +1,213 @@
#pragma once

#include <ATen/core/boxing/OperatorKernel.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/intrusive_ptr.h>

namespace c10 {

struct IValue;
using Stack = std::vector<IValue>;

class OperatorHandle;
class KernelFunction;

// This kernel implements the behavior of falling through to the next available
// registered dispatch key. The implementation of this function is FAST; it is
// no overhead to fallthrough to the next key. See cpp file for some more
// implementation notes; notably, this does NOT actually go through the
// boxing/unboxing codepath.
TORCH_API void fallthrough_kernel(
    OperatorKernel*,
    const OperatorHandle&,
    DispatchKeySet,
    Stack*);

// Note [Ambiguity in AutogradOther kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This error-reporting kernel is registered to the AutogradOther entry in the
// dispatch table when there is both a CompositeImplicitAutograd kernel and a
// backend kernel for ANY backend that maps to AutogradOther. To see why
// this is necessary in the AutogradOther case, it's helpful to first see
// why everything works out fine for a backend that has a reserved Autograd
// entry (see rule 2.2 in [Note] DispatchTable computation):
//
//    CPU     AutogradCPU
//    reg?    registers with...
//    -------------------------------------------------
//    y       Autograd registration takes precedence
//            over CompositeImplicitAutograd.
//            This is good, because the CPU specific backend
//            implementation is more specialized and typically better;
//            if we used the composite, we would bypass it.
//            (NB: the Autograd key is guaranteed to exist because
//            the autograd codegen requires it!)
//
//    n       CompositeImplicitAutograd takes precedence.
//            This is also good, because the Autograd
//            registration (if it exists) would try to redispatch
//            to the (non-existent) CPU implementation; by
//            using the composite, we ensure the operator
//            actually works.
//
// As you can see, when we have a specific Autograd key (AutogradCPU), we can
// decide whether or not to use the CompositeImplicitAutograd kernel or the
// Autograd kernel based on whether or not the backend kernel exists.
//
// However, for AutogradOther (which is the catchall autograd kernel for
// everything that doesn't have a specific Autograd key), we can't do this
// trick because there isn't any unique backend to peek at to disambiguate;
// if there are some backends that have implementations they prefer Autograd,
// but unimplemented backends would prefer CompositeImplicitAutograd. Rather
// than arbitrarily pick one or the other, we just register a kernel that raises
// an error and let the user decide how to proceed.
TORCH_API void ambiguous_autogradother_kernel(
    OperatorKernel*,
    const OperatorHandle&,
    DispatchKeySet,
    Stack*);

// Note [named_not_supported_kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This kernel implements reporting an error message saying that named tensor is
// not supported. This kernel doesn't rely on the Stack, and so it is special
// cased in the dispatcher to be triggered before we attempt boxing (so we can
// give a good error message in cases when boxing is not supported). When
// boxing is universally supported this can be removed.
[[noreturn]] TORCH_API void named_not_supported_kernel(
    OperatorKernel*,
    const OperatorHandle&,
    DispatchKeySet,
    Stack*);

/**
 * BoxedKernel is similar to a std::function storing a boxed kernel.
 */
class TORCH_API BoxedKernel final {
 public:
  // This is how boxed kernels are actually stored
  //
  // Note [Plumbing Keys Through The Dispatcher]
  // Benchmarks have shown that it is expensive for the dispatcher to read from
  // thread-local storage (TLS) upon every dispatch call into order to compute
  // which kernel to dispatch to.
  //
  // To mitigate this, we've updated the calling convention inside the
  // dispatcher to expect every kernel that it stores to have a first argument
  // of type DispatchKeySet.
  //
  // What are the invariants of the DispatchKeySet when it gets passed to a
  // kernel?
  // - All keys to the left of the current dispatch key have been masked out.
  //   (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the
  //   highest bit to be DispatchKey::Tracer)
  // - All other keys that dispatcher normally would have computed through TLS +
  //   global state + op arguments
  //   are still in the set.
  //
  // Kernels can then opt into using this keyset to save the dispatcher from
  // doing repeated work during redispatches: recalculating the highest-priority
  // dispatch key, which involves reading from TLS. Instead, the kernels that
  // opt in will calculate an updated DispatchKeySet directly from the old one,
  // and pass the updated set directly into the dispatcher upon redispatching.
  //
  // This is an opt-in mechanism: Kernels can automatically opt in by setting
  // the first argument in their signature to be of type DispatchKeySet. See the
  // kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for
  // examples.
  //
  // The mechanism for optionally passing that DispatchKeySet into the kernel
  // lives in make_boxed_from_unboxed_functor.h. See Note [Plumbing Keys Through
  // The Dispatcher 2] for details.
  using InternalBoxedKernelFunction =
      void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
  // This is the public API for how boxed kernels are defined
  using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
  using BoxedKernelFunction_withDispatchKeys =
      void(const OperatorHandle&, DispatchKeySet, Stack*);

  BoxedKernel();

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
  bool isValid() const;
  bool isFallthrough() const;

  /**
   * Call the function with boxed arguments.
   */
  void callBoxed(
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      Stack* stack) const;

  /**
   * Create a KernelFunction from a boxed function.
   *
   * Example:
   *
   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
   * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>();
   */
  template <BoxedKernelFunction* func>
  static BoxedKernel makeFromFunction();

  /**
   * TODO: This will only be useful if we write a backend fallback that plumbs
   * dispatch keys (currently there are none) See Note [Plumbing Keys Through
   * The Dispatcher] for details.
   */
  template <BoxedKernelFunction_withDispatchKeys* func>
  static BoxedKernel makeFromFunction();

  /**
   * Create a KernelFunction from a boxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * > public:
   * >   void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
   * > };
   * > BoxedKernel func =
   *     BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
   */
  template <class KernelFunctor>
  static BoxedKernel makeFromFunctor(
      std::unique_ptr<KernelFunctor> kernelFunctor);

  static BoxedKernel makeFallthrough();
  static BoxedKernel makeAmbiguousAutogradOther();
  static BoxedKernel makeNamedNotSupported();

 private:
  friend class KernelFunction;

  template <BoxedKernelFunction* func>
  static void make_boxed_function(
      OperatorKernel*,
      const OperatorHandle& opHandle,
      DispatchKeySet,
      Stack* stack);

  template <BoxedKernelFunction_withDispatchKeys* func>
  static void make_boxed_function(
      OperatorKernel*,
      const OperatorHandle& opHandle,
      DispatchKeySet,
      Stack* stack);

  explicit BoxedKernel(
      std::unique_ptr<OperatorKernel> functor,
      InternalBoxedKernelFunction* boxed_kernel_func);

  OperatorKernel* getFunctor() const;
  InternalBoxedKernelFunction* getFnPtr() const;

  c10::intrusive_ptr<OperatorKernel> functor_;
  InternalBoxedKernelFunction* boxed_kernel_func_;
};

} // namespace c10

#include <ATen/core/boxing/BoxedKernel_impl.h>
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h
ADDED
@@ -0,0 +1,106 @@
#pragma once

namespace c10 {

inline BoxedKernel::BoxedKernel() : functor_(), boxed_kernel_func_(nullptr) {}

inline BoxedKernel::BoxedKernel(
    std::unique_ptr<OperatorKernel> functor,
    InternalBoxedKernelFunction* boxed_kernel_func)
    : functor_(std::move(functor)), boxed_kernel_func_(boxed_kernel_func) {}

template <BoxedKernel::BoxedKernelFunction* func>
inline void BoxedKernel::make_boxed_function(
    OperatorKernel*,
    const OperatorHandle& opHandle,
    DispatchKeySet,
    Stack* stack) {
  // Note that we're dropping the DispatchKeySet argument.
  // See Note [Plumbing Keys Through The Dispatcher 2] for details.
  func(opHandle, stack);
}

template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
inline void BoxedKernel::make_boxed_function(
    OperatorKernel*,
    const OperatorHandle& opHandle,
    DispatchKeySet ks,
    Stack* stack) {
  // See Note [Plumbing Keys Through The Dispatcher 2] for details.
  func(opHandle, ks, stack);
}

inline bool BoxedKernel::isValid() const {
  return boxed_kernel_func_ != nullptr;
}

inline bool BoxedKernel::isFallthrough() const {
  return boxed_kernel_func_ == &fallthrough_kernel;
}

inline void BoxedKernel::callBoxed(
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Stack* stack) const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      boxed_kernel_func_ != nullptr,
      "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel.");
  (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
}

template <BoxedKernel::BoxedKernelFunction* func>
inline BoxedKernel BoxedKernel::makeFromFunction() {
  return BoxedKernel(
      nullptr, // no functor_ object
      &make_boxed_function<func>);
}

template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
inline BoxedKernel BoxedKernel::makeFromFunction() {
  return BoxedKernel(
      nullptr, // no functor_ object
      &make_boxed_function<func>);
}

inline BoxedKernel BoxedKernel::makeFallthrough() {
  return BoxedKernel(
      nullptr, // no functor_ object
      &fallthrough_kernel);
}

inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
  return BoxedKernel(
      nullptr, // no functor_ object
      &ambiguous_autogradother_kernel);
}

inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
  return BoxedKernel(
      nullptr, // no functor_ object
      &named_not_supported_kernel);
}

template <class KernelFunctor>
inline BoxedKernel BoxedKernel::makeFromFunctor(
    std::unique_ptr<KernelFunctor> kernelFunctor) {
  static_assert(
      std::is_base_of_v<OperatorKernel, KernelFunctor>,
      "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
  return BoxedKernel(
      std::move(kernelFunctor),
      [](OperatorKernel* kernel,
         const OperatorHandle& op,
         DispatchKeySet ks,
         Stack* stack) {
        (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
      });
}

inline OperatorKernel* BoxedKernel::getFunctor() const {
  return functor_.get();
}
inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
  return boxed_kernel_func_;
}

} // namespace c10
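Putting BoxedKernel.h and BoxedKernel_impl.h together, a boxed kernel is just a free function with the BoxedKernelFunction signature wrapped by makeFromFunction. The sketch below is my own illustration; the function names are assumptions and the body is a no-op.

// Minimal sketch of wrapping a boxed function with the BoxedKernel API above.
#include <ATen/core/boxing/BoxedKernel.h>

void my_boxed_fallback(const c10::OperatorHandle& /*op*/, c10::Stack* /*stack*/) {
  // A real boxed kernel would inspect and replace IValues on the stack here.
}

c10::BoxedKernel make_my_kernel() {
  return c10::BoxedKernel::makeFromFunction<&my_boxed_fallback>();
}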
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h
ADDED
@@ -0,0 +1,283 @@
#pragma once

#include <ATen/core/ATen_fwd.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/stack.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/TypeList.h>
#include <c10/util/intrusive_ptr.h>
#include <type_traits>

namespace c10 {

using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
                                 // to the c10 namespace.

class OperatorHandle;
struct OperatorKernel;
class KernelFunction;

template <typename T>
using has_symint = std::disjunction<
    std::is_same<c10::SymInt, T>,
    std::is_same<c10::SymIntArrayRef, T>,
    std::is_same<at::OptionalSymIntArrayRef, T>,
    std::is_same<std::optional<c10::SymInt>, T>>;

template <typename T>
struct remove_symint {
  using type = T;
};

template <>
struct remove_symint<c10::SymInt> {
  using type = int64_t;
};

template <>
struct remove_symint<at::OptionalSymIntArrayRef> {
  using type = OptionalIntArrayRef;
};

template <>
struct remove_symint<c10::SymIntArrayRef> {
  using type = c10::IntArrayRef;
};

template <>
struct remove_symint<std::optional<c10::SymInt>> {
  using type = std::optional<int64_t>;
};

template <bool symint, typename T>
struct maybe_keep_symint final {};

template <typename T>
struct maybe_keep_symint<true, T> {
  using type = T;
};

template <typename T>
struct maybe_keep_symint<false, T> {
  using type = typename remove_symint<T>::type;
};

template <typename T>
using fn_has_symint = typename guts::typelist::true_for_any_type<
    has_symint,
    typename guts::infer_function_traits<T>::type::parameter_types>;

template <typename T>
struct fn_remove_symint;

template <typename Ret, typename... Args>
struct fn_remove_symint<Ret(Args...)> {
  using type = Ret(typename remove_symint<Args>::type...);
};

/**
 * KernelFunction is similar to std::function but stores a kernel function.
 * You can create a KernelFunction from a boxed or unboxed
 * function/functor/lambda and call it in a boxed or unboxed way. If the way it
 * was created doesn't match the way it was called, it will do boxing or
 * unboxing as necessary.
 */
class TORCH_API KernelFunction final {
 public:
  using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
  using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
  using BoxedKernelFunction_withDispatchKeys =
      BoxedKernel::BoxedKernelFunction_withDispatchKeys;

  KernelFunction();

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
  bool isValidUnboxed() const;
  bool isValidSymUnboxed() const;
  bool isValid() const;
  bool isFallthrough() const;

  /**
   * Call the function in a boxed way.
   * If the kernel function was created with an unboxed function,
   * this will call an unboxing wrapper which then calls into that
   * unboxed function.
   *
   * Example:
   *
   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
   * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
   * > Tensor result = func.callBoxed(stack);
   *
   * Or, with an unboxed implementation:
   *
   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
   * >     [] (Tensor a, bool b) -> Tensor {...});
   * > Tensor result = func.callBoxed(stack);
   */
  void callBoxed(
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      Stack* stack) const;

  /**
   * Call the function in an unboxed way.
   * If the kernel function was created with a boxed function,
   * this will box all inputs and then call into that boxed function.
   *
   * Note that this doesn't work for all types yet.
   *
   * Example:
   *
   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
   * >     [] (Tensor a, bool b) -> Tensor {...});
   * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
   *
   * Or, with a boxed implementation:
   *
   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
   * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
   * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
   */
  template <class Return, class... Args>
  Return call(
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      Args... args) const;

  /**
   * Create a KernelFunction from a BoxedKernel.
   */
  static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);

  /**
   * Create a KernelFunction from a boxed function.
   *
   * Example:
   *
   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
   * > KernelFunction func =
   *     KernelFunction::makeFromBoxedFunction<&boxed_func>();
   */
  template <BoxedKernelFunction* func>
  static KernelFunction makeFromBoxedFunction();

  /**
   * TODO: This will only be useful if we write a backend fallback that plumbs
   * dispatch keys (currently there are none) See Note [Plumbing Keys Through
   * The Dispatcher] for details.
   */
  template <BoxedKernelFunction_withDispatchKeys* func>
  static KernelFunction makeFromBoxedFunction();

  /**
   * Create a KernelFunction from an unboxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * > public:
   * >   Tensor operator()(Tensor a, Tensor b) {...}
   * > };
   * > KernelFunction func =
   *     KernelFunction::makeFromUnboxedFunctor<MyFunctor>(std::make_unique<MyFunctor>());
   */
  template <bool AllowLegacyTypes = false, class KernelFunctor>
  static KernelFunction makeFromUnboxedFunctor(
      std::unique_ptr<OperatorKernel> kernelFunctor);

  /**
   * Create a KernelFunction from a boxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * > public:
   * >   void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
   * > };
   * > KernelFunction func =
   *     KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
   */
  template <class KernelFunctor>
  static KernelFunction makeFromBoxedFunctor(
      std::unique_ptr<KernelFunctor> kernelFunctor);

  /**
   * Create a KernelFunction from an unboxed function.
   * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
   * because knowing the function pointer as a template argument (i.e. at
   * compile time) allows the compiler to inline the function into its
   * unboxing wrapper and yields better performance when calling the function.
   *
   * Example:
   *
   * > Tensor unboxed_func(Tensor a, Tensor b) {...}
   * > KernelFunction func =
   *     KernelFunction::makeFromUnboxedFunction<decltype(unboxed_func),
   *     &unboxed_func>();
   */
  template <class FuncPtr, bool AllowLegacyTypes = false>
  static KernelFunction makeFromUnboxedFunction(FuncPtr);

  /**
   * Create a KernelFunction from an unboxed function.
   * KernelFunction::makeFromUnboxedFunction is usually a better choice than
   * this if you know the function pointer at compile time, see doc comment
   * there for an explanation.
   *
   * Example:
   *
   * > Tensor unboxed_func(Tensor a, Tensor b) {...}
   * > KernelFunction func =
   *     KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
   */
  template <bool AllowLegacyTypes = false, class FuncType>
  static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);

  static KernelFunction makeFallthrough();
  static KernelFunction makeAmbiguousAutogradOther();
  static KernelFunction makeNamedNotSupported();

  /**
   * Create a KernelFunction from an unboxed lambda.
   *
   * Example:
   *
   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
   * >     [] (Tensor a, bool b) -> Tensor {...});
   */
  template <bool AllowLegacyTypes = false, class Lambda>
  static std::enable_if_t<
      guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
      KernelFunction>
  makeFromUnboxedLambda(Lambda&& lambda);
  template <bool AllowLegacyTypes = false, class Lambda>
  static std::enable_if_t<
      !guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
      KernelFunction>
  makeFromUnboxedLambda(Lambda&& lambda);

  std::string dumpState() const;
  // For testing internal invariants only
  bool _equalsBoxedAndUnboxed(const KernelFunction&) const;

 private:
  explicit KernelFunction(
      std::unique_ptr<OperatorKernel> functor,
      InternalBoxedKernelFunction* boxed_kernel_func,
      void* unboxed_kernel_func,
      void* sym_unboxed_kernel_func);
  explicit KernelFunction(
      BoxedKernel boxed_fn,
      void* unboxed_kernel_func,
      void* sym_unboxed_kernel_func);

  BoxedKernel boxed_kernel_func_;
  void* unboxed_kernel_func_;
  void* sym_unboxed_kernel_func_;
};

} // namespace c10

#include <ATen/core/boxing/KernelFunction_impl.h>
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h
ADDED
@@ -0,0 +1,320 @@
#include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
#include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>

#include <c10/util/C++17.h>
#include <type_traits>

namespace c10 {

namespace detail {
template <typename Base, typename Child, typename... Args>
std::enable_if_t<
    !std::is_array_v<Base> && !std::is_array_v<Child> &&
        std::is_base_of_v<Base, Child>,
    std::unique_ptr<Base>>
make_unique_base(Args&&... args) {
  return std::unique_ptr<Base>(new Child(std::forward<Args>(args)...));
}
} // namespace detail

inline KernelFunction::KernelFunction()
    : boxed_kernel_func_(),
      unboxed_kernel_func_(nullptr),
      sym_unboxed_kernel_func_(nullptr) {}

inline KernelFunction::KernelFunction(
    std::unique_ptr<OperatorKernel> functor,
    InternalBoxedKernelFunction* boxed_kernel_func,
    void* unboxed_kernel_func,
    void* sym_unboxed_kernel_func = nullptr)
    : boxed_kernel_func_(std::move(functor), boxed_kernel_func),
      unboxed_kernel_func_(unboxed_kernel_func),
      sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}

inline KernelFunction::KernelFunction(
    BoxedKernel boxed_fn,
    void* unboxed_kernel_func,
    void* sym_unboxed_kernel_func = nullptr)
    : boxed_kernel_func_(std::move(boxed_fn)),
      unboxed_kernel_func_(unboxed_kernel_func),
      sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}

inline bool KernelFunction::isValidUnboxed() const {
  return unboxed_kernel_func_ != nullptr;
}

inline bool KernelFunction::isValidSymUnboxed() const {
  return sym_unboxed_kernel_func_ != nullptr;
}

inline bool KernelFunction::isValid() const {
  return boxed_kernel_func_.isValid();
}

inline bool KernelFunction::isFallthrough() const {
  return boxed_kernel_func_.isFallthrough();
}

inline void KernelFunction::callBoxed(
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Stack* stack) const {
  boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack);
}

template <class Return, class... Args>
inline Return callUnboxedKernelFunction(
    void* unboxed_kernel_func,
    OperatorKernel* functor,
    DispatchKeySet dispatchKeySet,
    Args&&... args) {
  using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...);
  ActualSignature* func =
      reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
  return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
}

// This template requires you to explicitly specify the argument you want to
// forward; it doesn't work if you try to deduce it
// NB: keep this in sync with cloneWithRealTypes in function_schema.cpp

template <typename T>
inline typename remove_symint<T>::type unpackSymInt(T x) {
  return x;
}

template <>
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
  return x.guard_int(__FILE__, __LINE__);
}

template <>
inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
    c10::SymIntArrayRef x) {
  return C10_AS_INTARRAYREF_SLOW(x);
}

template <>
inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
    std::optional<c10::SymInt> x) {
  return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
                       : std::nullopt;
}

template <>
inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
    at::OptionalSymIntArrayRef x) {
  return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
                       : std::nullopt;
}

template <class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Args... args) const {
  // note: Args above is intentionally not Args&&. We don't want perfect
  // forwarding, which would require Args to be deduced, but instead we
  // want callers to explicitly specify the Args.

  if constexpr (std::disjunction_v<has_symint<Args>...>) {
    if (sym_unboxed_kernel_func_ != nullptr) {
      auto* functor = boxed_kernel_func_.getFunctor();
      return callUnboxedKernelFunction<Return, Args...>(
return callUnboxedKernelFunction<Return, Args...>(
|
| 126 |
+
sym_unboxed_kernel_func_,
|
| 127 |
+
functor,
|
| 128 |
+
dispatchKeySet,
|
| 129 |
+
std::forward<Args>(args)...);
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
if (unboxed_kernel_func_ != nullptr) {
|
| 133 |
+
auto* functor = boxed_kernel_func_.getFunctor();
|
| 134 |
+
return callUnboxedKernelFunction<
|
| 135 |
+
Return,
|
| 136 |
+
typename remove_symint<Args>::type...>(
|
| 137 |
+
unboxed_kernel_func_,
|
| 138 |
+
functor,
|
| 139 |
+
dispatchKeySet,
|
| 140 |
+
unpackSymInt<Args>(args)...);
|
| 141 |
+
}
|
| 142 |
+
} else {
|
| 143 |
+
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
|
| 144 |
+
auto* functor = boxed_kernel_func_.getFunctor();
|
| 145 |
+
return callUnboxedKernelFunction<Return, Args...>(
|
| 146 |
+
unboxed_kernel_func_,
|
| 147 |
+
functor,
|
| 148 |
+
dispatchKeySet,
|
| 149 |
+
std::forward<Args>(args)...);
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
return impl::BoxedKernelWrapper<Return(Args...)>::call(
|
| 154 |
+
boxed_kernel_func_,
|
| 155 |
+
opHandle,
|
| 156 |
+
dispatchKeySet,
|
| 157 |
+
std::forward<Args>(args)...);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
inline KernelFunction KernelFunction::makeFromBoxedKernel(
|
| 161 |
+
BoxedKernel boxed_fn) {
|
| 162 |
+
return KernelFunction(
|
| 163 |
+
std::move(boxed_fn), nullptr); // no unboxed function pointer
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
template <KernelFunction::BoxedKernelFunction* func>
|
| 167 |
+
inline KernelFunction KernelFunction::makeFromBoxedFunction() {
|
| 168 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 169 |
+
BoxedKernel::makeFromFunction<func>());
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
template <KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
|
| 173 |
+
inline KernelFunction KernelFunction::makeFromBoxedFunction() {
|
| 174 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 175 |
+
BoxedKernel::makeFromFunction<func>());
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
inline KernelFunction KernelFunction::makeFallthrough() {
|
| 179 |
+
return KernelFunction::makeFromBoxedKernel(BoxedKernel::makeFallthrough());
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() {
|
| 183 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 184 |
+
BoxedKernel::makeAmbiguousAutogradOther());
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
inline KernelFunction KernelFunction::makeNamedNotSupported() {
|
| 188 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 189 |
+
BoxedKernel::makeNamedNotSupported());
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
template <bool AllowLegacyTypes, class KernelFunctor>
|
| 193 |
+
inline KernelFunction KernelFunction::makeFromUnboxedFunctor(
|
| 194 |
+
std::unique_ptr<OperatorKernel> kernelFunctor) {
|
| 195 |
+
#ifndef NDEBUG
|
| 196 |
+
// This assertion is costly for build time so it's debug-gated.
|
| 197 |
+
static_assert(
|
| 198 |
+
guts::is_functor<KernelFunctor>::value,
|
| 199 |
+
"Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
|
| 200 |
+
#endif
|
| 201 |
+
static_assert(
|
| 202 |
+
std::is_base_of_v<OperatorKernel, KernelFunctor>,
|
| 203 |
+
"Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
| 204 |
+
|
| 205 |
+
auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
|
| 206 |
+
void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
|
| 207 |
+
bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
|
| 208 |
+
return KernelFunction(
|
| 209 |
+
std::move(kernelFunctor),
|
| 210 |
+
&impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::
|
| 211 |
+
call,
|
| 212 |
+
is_symint ? nullptr : void_unboxed_fn,
|
| 213 |
+
is_symint ? void_unboxed_fn : nullptr);
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
template <class KernelFunctor>
|
| 217 |
+
inline KernelFunction KernelFunction::makeFromBoxedFunctor(
|
| 218 |
+
std::unique_ptr<KernelFunctor> kernelFunctor) {
|
| 219 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 220 |
+
BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
template <class FuncPtr, bool AllowLegacyTypes>
|
| 224 |
+
inline KernelFunction KernelFunction::makeFromUnboxedFunction(
|
| 225 |
+
FuncPtr func_ptr) {
|
| 226 |
+
static_assert(
|
| 227 |
+
is_compile_time_function_pointer<FuncPtr>::value,
|
| 228 |
+
"Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
|
| 229 |
+
static_assert(
|
| 230 |
+
!std::is_same_v<typename FuncPtr::FuncType, BoxedKernelFunction>,
|
| 231 |
+
"Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
|
| 232 |
+
#if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__)
|
| 233 |
+
TORCH_INTERNAL_ASSERT(
|
| 234 |
+
FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
|
| 235 |
+
#else
|
| 236 |
+
static_assert(
|
| 237 |
+
FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
|
| 238 |
+
#endif
|
| 239 |
+
|
| 240 |
+
#if !defined(C10_MOBILE)
|
| 241 |
+
(void)func_ptr; // Suppress unused variable warning
|
| 242 |
+
return makeFromUnboxedFunctor<
|
| 243 |
+
AllowLegacyTypes,
|
| 244 |
+
typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>(
|
| 245 |
+
detail::make_unique_base<
|
| 246 |
+
OperatorKernel,
|
| 247 |
+
typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>());
|
| 248 |
+
#else
|
| 249 |
+
// On mobile, we rather want to optimize for binary size than for performance,
|
| 250 |
+
// so let's not inline the kernel into the wrapper but use
|
| 251 |
+
// makeFromUnboxedRuntimeFunction instead.
|
| 252 |
+
return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr());
|
| 253 |
+
#endif
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
template <bool AllowLegacyTypes, class FuncType>
|
| 257 |
+
inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(
|
| 258 |
+
FuncType* func) {
|
| 259 |
+
static_assert(
|
| 260 |
+
guts::is_function_type<FuncType>::value,
|
| 261 |
+
"Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
|
| 262 |
+
static_assert(
|
| 263 |
+
!std::is_same_v<FuncType, BoxedKernelFunction>,
|
| 264 |
+
"Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
|
| 265 |
+
TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
|
| 266 |
+
|
| 267 |
+
return makeFromUnboxedFunctor<
|
| 268 |
+
AllowLegacyTypes,
|
| 269 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
|
| 270 |
+
detail::make_unique_base<
|
| 271 |
+
OperatorKernel,
|
| 272 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func));
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
template <bool AllowLegacyTypes, class Lambda>
|
| 276 |
+
inline std::enable_if_t<
|
| 277 |
+
guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
|
| 278 |
+
KernelFunction>
|
| 279 |
+
KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
|
| 280 |
+
static_assert(
|
| 281 |
+
guts::is_functor<std::decay_t<Lambda>>::value,
|
| 282 |
+
"Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
|
| 283 |
+
|
| 284 |
+
#if !defined(C10_MOBILE)
|
| 285 |
+
return makeFromUnboxedFunctor<
|
| 286 |
+
AllowLegacyTypes,
|
| 287 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 288 |
+
detail::make_unique_base<
|
| 289 |
+
OperatorKernel,
|
| 290 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 291 |
+
std::forward<Lambda>(lambda)));
|
| 292 |
+
#else
|
| 293 |
+
// On mobile, we rather want to optimize for binary size than for performance,
|
| 294 |
+
// so let's not inline the kernel into the wrapper but use
|
| 295 |
+
// makeFromUnboxedRuntimeFunction instead.
|
| 296 |
+
using FuncType =
|
| 297 |
+
typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type;
|
| 298 |
+
return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda);
|
| 299 |
+
#endif
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
template <bool AllowLegacyTypes, class Lambda>
|
| 303 |
+
inline std::enable_if_t<
|
| 304 |
+
!guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
|
| 305 |
+
KernelFunction>
|
| 306 |
+
KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
|
| 307 |
+
static_assert(
|
| 308 |
+
guts::is_functor<std::decay_t<Lambda>>::value,
|
| 309 |
+
"Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
|
| 310 |
+
|
| 311 |
+
return makeFromUnboxedFunctor<
|
| 312 |
+
AllowLegacyTypes,
|
| 313 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 314 |
+
detail::make_unique_base<
|
| 315 |
+
OperatorKernel,
|
| 316 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 317 |
+
std::forward<Lambda>(lambda)));
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
} // namespace c10
|
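The factory functions above (makeFromUnboxedFunction, makeFromUnboxedRuntimeFunction, makeFromUnboxedLambda, makeFallthrough, ...) are the entry points the dispatcher uses to turn user kernels into KernelFunction objects. As a rough sketch only, assuming a working libtorch build with these headers on the include path (the lambda below is illustrative and not an operator that appears in this diff):

#include <ATen/core/boxing/KernelFunction.h>

void kernel_function_sketch() {
  // Wrap a stateless lambda; both an unboxed entry point and a boxed
  // wrapper are generated for it.
  auto kf = c10::KernelFunction::makeFromUnboxedLambda(
      [](const at::Tensor& t) -> at::Tensor { return t + 1; });
  bool has_unboxed = kf.isValidUnboxed(); // true: unboxed function pointer stored
  bool has_boxed = kf.isValid();          // true: boxed wrapper available too
  (void)has_unboxed;
  (void)has_boxed;
}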
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h
ADDED
@@ -0,0 +1,27 @@
#pragma once
#include <c10/util/intrusive_ptr.h>

namespace c10 {

/**
 * Inherit from OperatorKernel to implement a c10 kernel.
 *
 * Example:
 * > namespace {
 * >   class my_kernel_cpu final : public c10::OperatorKernel {
 * >   public:
 * >     Tensor operator()(Tensor a, Tensor b) {...}
 * >   };
 * > }
 *
 * The kernel class is allowed to have members but these are equivalent
 * to global variables. The kernel implementation is responsible for
 * preventing race conditions on them.
 *
 * See below for how to register this kernel with PyTorch.
 */
struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target {
  ~OperatorKernel() override = default;
};

} // namespace c10
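The comment in OperatorKernel.h points out that functor members persist across calls, effectively acting as globals, so the kernel itself must synchronize access to them. A minimal sketch of that pattern, assuming the same headers are available (the counter and mutex are illustrative, not part of the header above):

#include <ATen/core/boxing/OperatorKernel.h>
#include <ATen/core/Tensor.h>
#include <mutex>

namespace {
class counting_kernel final : public c10::OperatorKernel {
 public:
  at::Tensor operator()(const at::Tensor& a) {
    std::lock_guard<std::mutex> guard(mutex_); // member state is shared across calls
    ++calls_;
    return a.clone();
  }

 private:
  std::mutex mutex_;
  int64_t calls_ = 0;
};
} // namespace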
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h
ADDED
@@ -0,0 +1,38 @@
#pragma once

#include <c10/core/CompileTimeFunctionPointer.h>

namespace c10::impl {
namespace detail {
template <class FuncPtr, class ReturnType, class ParameterList>
class WrapFunctionIntoFunctor_ {};
template <class FuncPtr, class ReturnType, class... Parameters>
class WrapFunctionIntoFunctor_<
    FuncPtr,
    ReturnType,
    guts::typelist::typelist<Parameters...>>
    final : public c10::OperatorKernel {
 public:
  C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) {
    return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...);
  }
};
} // namespace detail

// WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel
// functor. Since it is a compile time function pointer, many compilers can
// inline it into the wrapper and you don't get any performance overhead for
// wrapping.
template <class FuncPtr>
struct WrapFunctionIntoFunctor final {
  static_assert(
      c10::is_compile_time_function_pointer<FuncPtr>::value,
      "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN.");
  using type = detail::WrapFunctionIntoFunctor_<
      FuncPtr,
      typename guts::function_traits<typename FuncPtr::FuncType>::return_type,
      typename guts::function_traits<
          typename FuncPtr::FuncType>::parameter_types>;
};

} // namespace c10::impl
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h
ADDED
@@ -0,0 +1,41 @@
#pragma once

#include <c10/util/TypeTraits.h>

namespace c10::impl {

namespace detail {
template <class FuncType, class ReturnType, class ParameterList>
class WrapFunctionIntoRuntimeFunctor_ {};
template <class FuncType, class ReturnType, class... Parameters>
class WrapFunctionIntoRuntimeFunctor_<
    FuncType,
    ReturnType,
    guts::typelist::typelist<Parameters...>>
    final : public c10::OperatorKernel {
 public:
  template <class FuncType_>
  explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func)
      : kernel_func_(std::forward<FuncType_>(kernel_func)) {}

  decltype(auto) operator()(Parameters... args) {
    return kernel_func_(std::forward<Parameters>(args)...);
  }

 private:
  FuncType kernel_func_;
};
} // namespace detail

// WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that
// inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
// This can, for example, be used for lambdas, functors or even function
// pointers. In the case of function pointers, since it is a runtime function
// pointer, there is an overhead for calling it whenever the kernel is invoked.
template <class FuncType>
using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_<
    FuncType,
    typename guts::infer_function_traits_t<FuncType>::return_type,
    typename guts::infer_function_traits_t<FuncType>::parameter_types>;

} // namespace c10::impl
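The two wrapper headers above trade inlining for flexibility: WrapFunctionIntoFunctor bakes a TORCH_FN compile-time function pointer into the functor type so the call can inline, while WrapFunctionIntoRuntimeFunctor stores an arbitrary callable and pays an indirect call. A sketch of the compile-time flavor, with my_add as a hypothetical kernel function (not part of this diff):

#include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
#include <ATen/core/Tensor.h>

namespace {
at::Tensor my_add(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}
} // namespace

// TORCH_FN(my_add) yields a CompileTimeFunctionPointer value; its decltype is
// the FuncPtr type the wrapper expects, and ::type is a functor whose
// operator() calls my_add directly.
using MyAddFunctor =
    c10::impl::WrapFunctionIntoFunctor<decltype(TORCH_FN(my_add))>::type;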
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h
ADDED
@@ -0,0 +1,410 @@
#pragma once

// This file contains boxing (not unboxing) logic,
// i.e. how to make a vector<IValue> from a set of concrete arguments.

#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <c10/core/TensorOptions.h>

#include <ATen/core/boxing/BoxedKernel.h>

#include <c10/util/Metaprogramming.h>
#include <type_traits>

namespace c10::impl {

//
// utils
//

// is_mutable_tensor_ref
template <class T>
struct is_mutable_tensor_ref : std::false_type {};
template <>
struct is_mutable_tensor_ref<at::Tensor&> : std::true_type {};

// is_tuple_of_mutable_tensor_refs
//
template <class T, class Enable = void>
struct is_tuple_of_mutable_tensor_refs : std::false_type {};

template <class T>
struct is_tuple_of_mutable_tensor_refs<
    T,
    std::enable_if_t<guts::is_instantiation_of<std::tuple, T>::value, void>>
    : guts::typelist::
          all<is_mutable_tensor_ref, guts::typelist::from_tuple_t<T>> {};

// has_ivalue_to<T> tests the presence/absence of instance method
// IValue::to<T>()
//
template <class T, class Enable = void>
struct has_ivalue_to : std::false_type {};

template <class T>
struct ivalue_to_helper {
  using type = decltype(std::declval<IValue>().template to<T>());
};
template <class T>
using ivalue_to_helper_t = typename ivalue_to_helper<T>::type;

template <class T>
struct has_ivalue_to<T, std::void_t<ivalue_to_helper_t<T>>> : std::true_type {};

//
// boxing predicates
//

// A boxable arg type is one that IValue has a constructor for.
template <typename T>
using can_box = std::disjunction<
    std::is_constructible<IValue, std::decay_t<T>>,
    // TensorOptions are not directly constructible into IValue,
    // but torch::jit::push knows how to handle them
    std::is_same<TensorOptions, std::decay_t<T>>>;

template <typename... Ts>
using can_box_all = std::conjunction<can_box<Ts>...>;

// an unboxable result is one that can be extracted from an IValue
template <typename T>
using can_unbox = std::conjunction<
    std::disjunction<
        has_ivalue_to<T>,
        // void returns are ok
        std::is_same<void, T>>,
    std::negation<std::is_lvalue_reference<T>>>;

//
// boxArgs - utility for pushing unboxed args onto IValue stack
//
template <class... Args>
torch::jit::Stack boxArgs(Args... args) {
  // TODO Reuse stack vector instead of allocating?
  torch::jit::Stack stack;
  stack.reserve(sizeof...(Args));
  torch::jit::push(stack, std::forward<Args>(args)...);
  return stack;
}

template <class T>
inline constexpr size_t boxed_size_one() {
  static_assert(
      !std::is_same_v<std::decay_t<T>, c10::TensorOptions>,
      "need to patch this path to support TensorOptions passed by reference");
  return 1;
}

// torch::jit::push pushes 4 values for a TensorOptions; this needs to
// be kept in sync.
template <>
inline constexpr size_t boxed_size_one<c10::TensorOptions>() {
  return 4;
}

// NOTE: this could probably be simplified with C++17 fold expressions.
template <typename...>
struct BoxedSize : std::integral_constant<size_t, 0> {};
template <class T, class... Args>
struct BoxedSize<T, Args...>
    : std::integral_constant<
          size_t,
          boxed_size_one<T>() + BoxedSize<Args...>::value> {};

template <class... Args>
static inline constexpr size_t boxed_size() {
  return BoxedSize<Args...>::value;
}

template <typename T>
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValue*& dest, T& arg) {
  new (dest++) IValue(arg);
}

C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(
    IValue*& dest,
    c10::TensorOptions options) {
  new (dest++) IValue(c10::typeMetaToScalarType(options.dtype()));
  new (dest++) IValue(options.layout());
  new (dest++) IValue(options.device());
  new (dest++) IValue(options.pinned_memory());
}

inline void boxArgsToStack(IValue*&) {}

template <typename T, typename... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(
    IValue*& dest,
    T& arg,
    Args&... args) {
  boxToStack(dest, arg);
  boxArgsToStack(dest, args...);
}

//
// PopResult is a helper class whose specializations handle popping single and
// multiple return values, respectively.
//
template <class Result>
struct PopResult final {
  static Result call(Stack& stack) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == 1,
        "Boxed kernel was expected to return one value on the stack, ",
        "but instead pushed ",
        stack.size(),
        " values.");
    return std::move(stack[0]).to<Result>();
  }
};

template <class... Types>
struct PopResult<std::tuple<Types...>> final {
  using Result = std::tuple<Types...>;

  static Result call(Stack& stack) {
    // for tuple return types, boxed kernel has pushed multiple values onto the
    // stack
    constexpr int RetCount = sizeof...(Types);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == RetCount,
        "Boxed kernel was expected to return ",
        RetCount,
        " values on the stack, ",
        "but instead pushed ",
        stack.size(),
        " values.");
    return pop_to_tuple_impl(stack, std::make_index_sequence<RetCount>());
  }

 private:
  // note: this has been moved into its own helper only to avoid a parse error
  // on `indices` otherwise. I'm sure there's an incantation that slips it past
  // the parser but eh
  template <size_t... indices>
  static Result pop_to_tuple_impl(
      Stack& stack,
      std::index_sequence<indices...>) {
    return std::make_tuple((std::move(stack[indices]).template to<Types>())...);
  }
};

//
// BoxedKernelWrapper
//
// For a given function type FT, BoxedKernelWrapper<FT> implements
// a `call` method that
// - takes a boxed kernel and unboxed arguments as specified by FT,
// - calls `boxArgs` to box the arguments
// - calls the boxed kernel
// - unboxes and returns the result
//
// The partial specializations below handle various cases: in
// particular, not all types appearing in op signatures are supported,
// and ops returning references have nonstandard wrapper implementations.
//

// 1. The base specialization of BoxedKernelWrapper should never be
// instantiated. A "no call method defined on BoxedKernelWrapper" compile error
// means that an op signature has failed to trigger any of the partial
// specializations that follow this one.
//
template <class FuncType, class Enable = void>
struct BoxedKernelWrapper {
  // The reason we're not just doing straight up static_assert(false, ...) here:
  // Basically, the way to make sure a static_assert only fires if a template
  // is actually instantiated (rather than every time the file is parsed) is to
  // use template parameters in the expression, e.g. FuncType here. However,
  // since `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the
  // same effect.
  static_assert(
      sizeof(FuncType) != sizeof(FuncType),
      "Function signature contains one or more unsupported parameter and/or return types. "
      "Look for a nearby error like "
      "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "
      "- (your function type) is the unsupported signature.");
};

//
// 2. Supported signatures, other than those involving non-const Tensor refs -
// i.e., "functional" ops.
//

template <class Result, class... Args>
struct BoxedKernelWrapper<
    Result(Args...),
    std::enable_if_t<
        can_box_all<Args...>::value && can_unbox<Result>::value &&
            !is_tuple_of_mutable_tensor_refs<Result>::value,
        void>> {
  static Result call(
      const BoxedKernel& boxed_kernel_func,
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      Args... args) {
    torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);

    if constexpr (!std::is_same_v<void, Result>) {
      // op has pushed one or more values onto the stack.
      return PopResult<Result>::call(stack);
    } else {
      // op returns void, boxed kernel has pushed nothing onto stack.
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          stack.empty(),
          "Boxed kernel was expected to return no values on the stack, ",
          "but instead returned ",
          stack.size(),
          " values.");
    }
  }
};

//
// 3. in-place ops take a single non-const Tensor reference
// as their first argument, and return it.
//
// Note: all signatures matching this pattern are assumed to be for such ops.
// Because of this, the generated BoxedKernelWrapper specializations simply
// return the in-place argument.
//

template <class... OtherArgs>
struct BoxedKernelWrapper<
    at::Tensor&(at::Tensor&, OtherArgs...),
    std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
  static at::Tensor& call(
      const BoxedKernel& boxed_kernel_func,
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      at::Tensor& outArg,
      OtherArgs... otherArgs) {
    torch::jit::Stack stack = boxArgs<at::Tensor&, OtherArgs...>(
        outArg, std::forward<OtherArgs>(otherArgs)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == 1,
        "Boxed kernel was expected to return a single value on the stack, ",
        "but instead returned ",
        stack.size(),
        " values.");

    return outArg;
  }
};

//
// 3.5. In-process migration to make in-place ops take and return
// const references instead.
template <class... OtherArgs>
struct BoxedKernelWrapper<
    const at::Tensor&(const at::Tensor&, OtherArgs...),
    std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
  static const at::Tensor& call(
      const BoxedKernel& boxed_kernel_func,
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      const at::Tensor& outArg,
      OtherArgs... otherArgs) {
    torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == 1,
        "Boxed kernel was expected to return a single value on the stack, ",
        "but instead returned ",
        stack.size(),
        " values.");

    return outArg;
  }
};

//
// 4. out of place ops that take a single non-const Tensor reference as their
// final argument, and also return it.
//
// Note: all signatures matching this pattern are assumed to be for such ops.
// This assumption permits the generated BoxedKernelWrapper specializations to
// simply return out arguments.
//
template <class FirstArg, class... RestArgs>
struct BoxedKernelWrapper<
    at::Tensor&(FirstArg, RestArgs...),
    std::enable_if_t<
        can_box_all<FirstArg, RestArgs...>::value
            // this skips over in-place kernels with a non-const Tensor
            // arg at the front, so those can unambiguously trigger the
            // preceding specialization.
            && !is_mutable_tensor_ref<FirstArg>::value,
        void>> {
  static at::Tensor& call(
      const BoxedKernel& boxed_kernel_func,
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      FirstArg firstArg,
      RestArgs... restArgs) {
    torch::jit::Stack stack = boxArgs<FirstArg, RestArgs...>(
        std::forward<FirstArg>(firstArg), std::forward<RestArgs>(restArgs)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == 1,
        "Boxed kernel was expected to return a single value on the stack, ",
        "but instead returned ",
        stack.size(),
        " values.");

    // reusing restArgs after it has been forwarded here is ok because we know
    // that the last element is of type `Tensor&`.
    return std::get<sizeof...(RestArgs) - 1>(
        std::tuple<RestArgs...>{restArgs...});
  }
};

//
// 5. out of place ops that take multiple non-const Tensor references as their
// final arguments, and return them in a std::tuple.
//
// Note: all signatures matching this pattern are assumed to be for such ops.
// This assumption permits the generated BoxedKernelWrapper specializations to
// simply return the out arguments.
//
template <class Result, class... Args>
struct BoxedKernelWrapper<
    Result(Args...),
    std::enable_if_t<
        can_box_all<Args...>::value &&
            is_tuple_of_mutable_tensor_refs<Result>::value,
        void>> {
  static Result call(
      const BoxedKernel& boxed_kernel_func,
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      Args... args) {
    using ArgTuple = std::tuple<Args...>;
    constexpr int RetCount = std::tuple_size<Result>();

    torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == RetCount,
        "Boxed kernel was expected to return ",
        RetCount,
        " values on the stack, ",
        "but instead returned ",
        stack.size(),
        " values.");

    // reusing args after it has been forwarded here is ok because we know
    // that the last RetCount elements are of type `Tensor&`.
    auto result = guts::tuple_take<ArgTuple, -RetCount>(
        ArgTuple{std::forward<Args>(args)...});
    static_assert(
        std::is_same_v<Result, decltype(result)>,
        "The parameter list of an op returning a tuple of Tensor references "
        "must end with an equal number of Tensor reference parameters.");
    return result;
  }
};

} // namespace c10::impl
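boxing.h above is the bridge that lets an unboxed call site reach a boxed kernel: arguments are pushed onto a torch::jit::Stack as IValues, the boxed kernel consumes the stack, and the wrapper pops the results back off. A rough sketch of the first half of that flow using the boxArgs helper defined above (calling it directly like this is only illustrative; BoxedKernelWrapper::call normally does it internally):

#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/Tensor.h>

void boxing_sketch(const at::Tensor& t) {
  // Box two unboxed arguments into IValues, as BoxedKernelWrapper::call does
  // right before invoking BoxedKernel::callBoxed.
  torch::jit::Stack stack = c10::impl::boxArgs(t, int64_t{3});
  // A boxed kernel would now pop its inputs from `stack` and push its outputs;
  // PopResult<Result>::call(stack) is what converts those back to C++ values.
  (void)stack;
}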
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
ADDED
|
@@ -0,0 +1,785 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/IListRef.h>
|
| 4 |
+
#include <ATen/core/boxing/OperatorKernel.h>
|
| 5 |
+
#include <ATen/core/ivalue.h>
|
| 6 |
+
#include <ATen/core/stack.h>
|
| 7 |
+
#include <c10/util/Metaprogramming.h>
|
| 8 |
+
#include <c10/util/TypeList.h>
|
| 9 |
+
#include <c10/util/intrusive_ptr.h>
|
| 10 |
+
|
| 11 |
+
#include <utility>
|
| 12 |
+
|
| 13 |
+
namespace c10 {
|
| 14 |
+
|
| 15 |
+
using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
|
| 16 |
+
// to the c10 namespace.
|
| 17 |
+
class OperatorHandle;
|
| 18 |
+
|
| 19 |
+
/*
|
| 20 |
+
* [Note: Argument forwarding in the dispatcher]
|
| 21 |
+
*
|
| 22 |
+
* The dispatcher uses a somewhat unusual way to forward arguments through
|
| 23 |
+
* several layers of wrapper functions. This can be confusing because an
|
| 24 |
+
* experienced C++ programmer would look at this and think "oh this is supposed
|
| 25 |
+
* to be forwarding a universal reference but the && is missing. This is a
|
| 26 |
+
* bug.". It is not a bug. The common way in C++ to forward arguments is to use
|
| 27 |
+
* universal references:
|
| 28 |
+
*
|
| 29 |
+
* > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
|
| 30 |
+
*
|
| 31 |
+
* but that relies on inferring the correct reference type (i.e. value vs & vs
|
| 32 |
+
* &&) from the argument. In our case, we cannot rely on the argument as
|
| 33 |
+
* supplied by the caller, because that could infer a different reference type
|
| 34 |
+
* than was used in the kernel function. The correct reference type is dictated
|
| 35 |
+
* by the kernel signature and must be identical since we cast function pointers
|
| 36 |
+
* through void* pointers and mismatches would be UB. So we need a forwarding
|
| 37 |
+
* pattern that determines the reference type to use by looking at the
|
| 38 |
+
* explicitly supplied operator signature, not by looking at the argument we're
|
| 39 |
+
* calling it with.
|
| 40 |
+
*
|
| 41 |
+
* What does std::forward do, exactly?
|
| 42 |
+
* ------------------------------------
|
| 43 |
+
* std::forward<T>(t) is a way to cast t to the reference type supplied in T.
|
| 44 |
+
* Let's assume decay_t<T> == U and T is either U or some reference of U.
|
| 45 |
+
* - std::forward<T&>(t) will return U&, no matter what kind of reference t is.
|
| 46 |
+
* - std::forward<T&&>(t) will return U&&, no matter what kind of reference t
|
| 47 |
+
* is.
|
| 48 |
+
* - std::forward<T>(t) will return U&& (not U!), no matter what kind of
|
| 49 |
+
* reference t is.
|
| 50 |
+
*
|
| 51 |
+
* For universal references, that means that in the following function
|
| 52 |
+
* > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
|
| 53 |
+
*
|
| 54 |
+
* - when called with arg being a rvalue reference or non-reference value, T
|
| 55 |
+
* gets inferred to be a non-reference U, and std::forward<T>(t) will return
|
| 56 |
+
* U&&, correctly moving the argument.
|
| 57 |
+
* - when called with arg behind a lvalue reference, T gets inferred to be U&
|
| 58 |
+
* because that's the only way to match the signature (in C++, a type that is
|
| 59 |
+
* (T&)&& will collapse to T&). That means std::forward<T>(t) will return U& and
|
| 60 |
+
* the value will not be moved but passed on as a lvalue reference.
|
| 61 |
+
*
|
| 62 |
+
* How do we use that?
|
| 63 |
+
* ------------------------------------
|
| 64 |
+
* But std::forward can also be used outside of the common "universal
|
| 65 |
+
* forwarding" pattern to change reference types. So instead of following the
|
| 66 |
+
* common C++ pattern, we notice what std::forward<T>() actually does, and that
|
| 67 |
+
* is it takes a value and changes its reference to the type of reference passed
|
| 68 |
+
* in as T. If we don't infer T but explicitly specify it, we can use this to
|
| 69 |
+
* forward based on an explicitly specified reference type instead of the
|
| 70 |
+
* inferred argument type.
|
| 71 |
+
*
|
| 72 |
+
* This is why many of the dispatcher functions look like
|
| 73 |
+
* > template<class T> func(T t) { func2<T>(std::forward<T>(t)); }
|
| 74 |
+
* instead of the common
|
| 75 |
+
* > template<class T> func(T&& t) { func2(std::forward<T>(t)); }
|
| 76 |
+
*
|
| 77 |
+
* and are expected to be called by explicitly specifying the template
|
| 78 |
+
* parameters in a way that matches the expected operator signature at each call
|
| 79 |
+
* site.
|
| 80 |
+
*/
|
| 81 |
+
|
| 82 |
+
namespace impl {
|
| 83 |
+
// supported_primitive_arg_types defines which primitive types we allow in
|
| 84 |
+
// kernel functions as arguments or returns.
|
| 85 |
+
// Additionally, we support lists, dicts and optionals containing these types.
|
| 86 |
+
using supported_primitive_arg_types = guts::typelist::typelist<
|
| 87 |
+
int64_t,
|
| 88 |
+
double,
|
| 89 |
+
bool,
|
| 90 |
+
std::string_view,
|
| 91 |
+
at::Tensor,
|
| 92 |
+
at::Scalar,
|
| 93 |
+
c10::QScheme,
|
| 94 |
+
c10::ScalarType,
|
| 95 |
+
c10::Device,
|
| 96 |
+
c10::DeviceIndex,
|
| 97 |
+
c10::Layout,
|
| 98 |
+
c10::MemoryFormat,
|
| 99 |
+
at::Dimname>;
|
| 100 |
+
|
| 101 |
+
// We have an unboxed functor in hand that takes C++ arguments, and
|
| 102 |
+
// we're building a boxed functor wrapper for it that takes IValues.
|
| 103 |
+
// So "outside" is boxed and "inside" is unboxed.
|
| 104 |
+
//
|
| 105 |
+
// So a valid input type is one that our boxed functor wrapper can
|
| 106 |
+
// unbox from an IValue into a C++ value.
|
| 107 |
+
//
|
| 108 |
+
// Whereas a valid output type is one that our wrapper can recieve
|
| 109 |
+
// as a C++ value from the unboxed functor, and box into an IValue.
|
| 110 |
+
|
| 111 |
+
//
|
| 112 |
+
// assert_is_valid_input_type
|
| 113 |
+
// checks that T can be unboxed from an IValue into a C++ value.
|
| 114 |
+
//
|
| 115 |
+
|
| 116 |
+
template <class T, bool AllowDeprecatedTypes, class Enable = void>
|
| 117 |
+
struct assert_is_valid_input_type {
|
| 118 |
+
assert_is_valid_input_type() {
|
| 119 |
+
if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
|
| 120 |
+
value) {
|
| 121 |
+
/* everything is ok, this is a primitive type */
|
| 122 |
+
} else {
|
| 123 |
+
/* otherwise this must be an instance of a valid custom class, since it
|
| 124 |
+
can only have been created via IValue(x), which ensures this. */
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
};
|
| 128 |
+
|
| 129 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 130 |
+
struct assert_is_valid_input_type<std::optional<T>, AllowDeprecatedTypes>
|
| 131 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {};
|
| 132 |
+
|
| 133 |
+
template <bool AllowDeprecatedTypes, class... Args>
|
| 134 |
+
struct TypeCheckHelper;
|
| 135 |
+
|
| 136 |
+
template <bool AllowDeprecatedTypes>
|
| 137 |
+
struct TypeCheckHelper<AllowDeprecatedTypes> {};
|
| 138 |
+
|
| 139 |
+
template <bool AllowDeprecatedTypes, class Head, class... Rest>
|
| 140 |
+
struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...>
|
| 141 |
+
: TypeCheckHelper<AllowDeprecatedTypes, Rest...> {
|
| 142 |
+
assert_is_valid_input_type<Head, AllowDeprecatedTypes> check;
|
| 143 |
+
};
|
| 144 |
+
|
| 145 |
+
template <class... Contained, bool AllowDeprecatedTypes>
|
| 146 |
+
struct assert_is_valid_input_type<
|
| 147 |
+
std::tuple<Contained...>,
|
| 148 |
+
AllowDeprecatedTypes>
|
| 149 |
+
: TypeCheckHelper<AllowDeprecatedTypes, Contained...> {};
|
| 150 |
+
|
| 151 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 152 |
+
struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes>
|
| 153 |
+
: assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
|
| 154 |
+
static_assert(
|
| 155 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 156 |
+
"You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 157 |
+
};
|
| 158 |
+
|
| 159 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 160 |
+
struct assert_is_valid_input_type<
|
| 161 |
+
std::unordered_map<Key, Value>,
|
| 162 |
+
AllowDeprecatedTypes>
|
| 163 |
+
: assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
|
| 164 |
+
static_assert(
|
| 165 |
+
AllowDeprecatedTypes,
|
| 166 |
+
"You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
|
| 167 |
+
static_assert(
|
| 168 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 169 |
+
"You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 170 |
+
};
|
| 171 |
+
|
| 172 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 173 |
+
struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes>
|
| 174 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 175 |
+
static_assert(
|
| 176 |
+
!std::is_same_v<T, at::Scalar>,
|
| 177 |
+
"You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 178 |
+
};
|
| 179 |
+
|
| 180 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 181 |
+
struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
|
| 182 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 183 |
+
static_assert(
|
| 184 |
+
!std::is_same_v<T, at::Scalar>,
|
| 185 |
+
"You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 186 |
+
};
|
| 187 |
+
|
| 188 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 189 |
+
struct assert_is_valid_input_type<
|
| 190 |
+
c10::OptionalArrayRef<T>,
|
| 191 |
+
AllowDeprecatedTypes>
|
| 192 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 193 |
+
static_assert(
|
| 194 |
+
!std::is_same_v<T, at::Scalar>,
|
| 195 |
+
"You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 196 |
+
};
|
| 197 |
+
|
| 198 |
+
template <class T, size_t N, bool AllowDeprecatedTypes>
|
| 199 |
+
struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
|
| 200 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 201 |
+
static_assert(
|
| 202 |
+
!std::is_same_v<T, at::Scalar>,
|
| 203 |
+
"You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
|
| 204 |
+
};
|
| 205 |
+
|
| 206 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 207 |
+
struct assert_is_valid_input_type<
|
| 208 |
+
T,
|
| 209 |
+
AllowDeprecatedTypes,
|
| 210 |
+
std::enable_if_t<std::is_same_v<float, T>>> {
|
| 211 |
+
// There is no reason to support float when we have double. Keep the API lean.
|
| 212 |
+
static_assert(
|
| 213 |
+
guts::false_t<T>::value,
|
| 214 |
+
"You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
|
| 215 |
+
};
|
| 216 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 217 |
+
struct assert_is_valid_input_type<
|
| 218 |
+
T,
|
| 219 |
+
AllowDeprecatedTypes,
|
| 220 |
+
std::enable_if_t<std::is_same_v<const char*, T>>> {
|
| 221 |
+
static_assert(
|
| 222 |
+
guts::false_t<T>::value,
|
| 223 |
+
"You tried to register a kernel with an unsupported input type: const char*. Please use std::string_view instead.");
|
| 224 |
+
};
|
| 225 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 226 |
+
struct assert_is_valid_input_type<
|
| 227 |
+
T,
|
| 228 |
+
AllowDeprecatedTypes,
|
| 229 |
+
std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
|
| 230 |
+
static_assert(
|
| 231 |
+
guts::false_t<T>::value,
|
| 232 |
+
"You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead.");
|
| 233 |
+
};
|
| 234 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 235 |
+
struct assert_is_valid_input_type<
|
| 236 |
+
T,
|
| 237 |
+
AllowDeprecatedTypes,
|
| 238 |
+
std::enable_if_t<
|
| 239 |
+
std::is_integral_v<T> &&
|
| 240 |
+
!guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
|
| 241 |
+
static_assert(
|
| 242 |
+
guts::false_t<T>::value,
|
| 243 |
+
"You tried to register a kernel with an unsupported integral input type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
|
| 244 |
+
};
|
| 245 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 246 |
+
struct assert_is_valid_input_type<
|
| 247 |
+
T,
|
| 248 |
+
AllowDeprecatedTypes,
|
| 249 |
+
std::enable_if_t<std::is_same_v<const c10::SymInt&, T>>> {
|
| 250 |
+
static_assert(
|
| 251 |
+
guts::false_t<T>::value,
|
| 252 |
+
"You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
|
| 253 |
+
};
|
| 254 |
+
|
| 255 |
+
// TODO: it probably would be good to tighten this up quite a bit more with
|
| 256 |
+
// an explicit list for everything
|
| 257 |
+
|
| 258 |
+
//
|
| 259 |
+
// assert_is_valid_output_type
|
| 260 |
+
//
|
| 261 |
+
|
| 262 |
+
template <class T, bool AllowDeprecatedTypes, class Enable = void>
|
| 263 |
+
struct assert_is_valid_output_type {
|
| 264 |
+
assert_is_valid_output_type() {
|
| 265 |
+
if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
|
| 266 |
+
value) {
|
| 267 |
+
/* everything is ok, this is a primitive type */
|
| 268 |
+
} else {
|
| 269 |
+
/* otherwise T is verified to be a registered custom class in the IValue
|
| 270 |
+
constructor, so no benefit in double-checking here */
|
| 271 |
+
}
|
| 272 |
+
}
|
| 273 |
+
};
|
| 274 |
+
|
| 275 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 276 |
+
struct assert_is_valid_output_type<std::optional<T>, AllowDeprecatedTypes>
|
| 277 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
|
| 278 |
+
|
| 279 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 280 |
+
struct assert_is_valid_output_type<
|
| 281 |
+
c10::OptionalArrayRef<T>,
|
| 282 |
+
AllowDeprecatedTypes>
|
| 283 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
|
| 284 |
+
|
| 285 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 286 |
+
struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes>
|
| 287 |
+
: assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
|
| 288 |
+
static_assert(
|
| 289 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 290 |
+
"You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 291 |
+
static_assert(
|
| 292 |
+
!std::is_same_v<Value, at::Scalar>,
|
| 293 |
+
"You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
|
| 294 |
+
};
|
| 295 |
+
|
| 296 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 297 |
+
struct assert_is_valid_output_type<
|
| 298 |
+
std::unordered_map<Key, Value>,
|
| 299 |
+
AllowDeprecatedTypes>
|
| 300 |
+
: assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
|
| 301 |
+
static_assert(
|
| 302 |
+
AllowDeprecatedTypes,
|
| 303 |
+
"You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
|
| 304 |
+
static_assert(
|
| 305 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 306 |
+
"You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 307 |
+
static_assert(
|
| 308 |
+
!std::is_same_v<Value, at::Scalar>,
|
| 309 |
+
"You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
|
| 310 |
+
};
|
| 311 |
+
|
| 312 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 313 |
+
struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes>
|
| 314 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
|
| 315 |
+
static_assert(
|
| 316 |
+
!std::is_same_v<T, at::Scalar>,
|
| 317 |
+
"You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 318 |
+
};
|
| 319 |
+
|
| 320 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 321 |
+
struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes>
|
| 322 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
|
| 323 |
+
static_assert(
|
| 324 |
+
!std::is_same_v<T, at::Scalar>,
|
| 325 |
+
"You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 326 |
+
// TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel
|
| 327 |
+
// with an unsupported output type: std::vector<T>. Please use List<T>
|
| 328 |
+
// instead.");
|
| 329 |
+
};
|
| 330 |
+
|
| 331 |
+
template <class T, size_t N, bool AllowDeprecatedTypes>
|
| 332 |
+
struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
|
| 333 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
|
| 334 |
+
static_assert(
|
| 335 |
+
!std::is_same_v<T, at::Scalar>,
|
| 336 |
+
"You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
|
| 337 |
+
};
|
| 338 |
+
|
| 339 |
+
// The following specialisations of assert_is_valid_output_type are technically
|
| 340 |
+
// not necessary since we would hit the base case and show an error message
|
| 341 |
+
// there if they didn't exist, but we can show a better error message
|
| 342 |
+
// in some common error scenarios.
|
| 343 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 344 |
+
struct assert_is_valid_output_type<
|
| 345 |
+
T,
|
| 346 |
+
AllowDeprecatedTypes,
|
| 347 |
+
std::enable_if_t<std::is_same_v<float, T>>> {
|
| 348 |
+
// There is no reason to support float when we have double. Keep the API lean.
|
| 349 |
+
static_assert(
|
| 350 |
+
guts::false_t<T>::value,
|
| 351 |
+
"You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
|
| 352 |
+
};
|
| 353 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 354 |
+
struct assert_is_valid_output_type<
|
| 355 |
+
T,
|
| 356 |
+
AllowDeprecatedTypes,
|
| 357 |
+
std::enable_if_t<std::is_same_v<const char*, T>>> {
|
| 358 |
+
static_assert(
|
| 359 |
+
guts::false_t<T>::value,
|
| 360 |
+
"You tried to register a kernel with an unsupported output type: const char*. Please use std::string_view instead.");
|
| 361 |
+
};
|
| 362 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 363 |
+
struct assert_is_valid_output_type<
|
| 364 |
+
T,
|
| 365 |
+
AllowDeprecatedTypes,
|
| 366 |
+
std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
|
| 367 |
+
static_assert(
|
| 368 |
+
guts::false_t<T>::value,
|
| 369 |
+
"You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead.");
|
| 370 |
+
};
|
| 371 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 372 |
+
struct assert_is_valid_output_type<
|
| 373 |
+
T,
|
| 374 |
+
AllowDeprecatedTypes,
|
| 375 |
+
std::enable_if_t<
|
| 376 |
+
std::is_integral_v<T> &&
|
| 377 |
+
!guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
|
| 378 |
+
static_assert(
|
| 379 |
+
guts::false_t<T>::value,
|
| 380 |
+
"You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
|
| 381 |
+
};
|
| 382 |
+
|
| 383 |
+
// ivalue_to_arg
|
| 384 |
+
|
| 385 |
+
template <class T>
|
| 386 |
+
struct decay_if_not_tensor final {
|
| 387 |
+
using type = std::decay_t<T>;
|
| 388 |
+
};
|
| 389 |
+
|
| 390 |
+
template <>
|
| 391 |
+
struct decay_if_not_tensor<at::Tensor&> final {
|
| 392 |
+
using type = at::Tensor&;
|
| 393 |
+
};
|
| 394 |
+
|
| 395 |
+
template <>
|
| 396 |
+
struct decay_if_not_tensor<const at::Tensor&> final {
|
| 397 |
+
using type = const at::Tensor&;
|
| 398 |
+
};
|
| 399 |
+
|
| 400 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 401 |
+
struct ivalue_to_arg final {
|
| 402 |
+
static decltype(auto) call(IValue& v) {
|
| 403 |
+
assert_is_valid_input_type<T, AllowDeprecatedTypes>();
|
| 404 |
+
return std::move(v).to<T>();
|
| 405 |
+
}
|
| 406 |
+
};
|
| 407 |
+
|
| 408 |
+
// The following two specializations take advantage of specialized
|
| 409 |
+
// `toTensor()` overloads on IValue to avoid copying.
|
| 410 |
+
template <bool AllowDeprecatedTypes>
|
| 411 |
+
struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
|
| 412 |
+
// We cannot use the default implementation if they asked for a
|
| 413 |
+
// `at::Tensor&` because it moves from the IValue, so it can't get
|
| 414 |
+
// an lvalue reference.
|
| 415 |
+
static at::Tensor& call(IValue& v) {
|
| 416 |
+
// Tensor& is valid, don't bother asserting
|
| 417 |
+
return v.toTensor();
|
| 418 |
+
}
|
| 419 |
+
};
|
| 420 |
+
|
| 421 |
+
template <bool AllowDeprecatedTypes>
|
| 422 |
+
struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
|
| 423 |
+
// We should not use the default implementation if they asked for
|
| 424 |
+
// a `const at::Tensor&` because it moves from the IValue and they
|
| 425 |
+
// didn't ask for that.
|
| 426 |
+
static const at::Tensor& call(IValue& v) {
|
| 427 |
+
// const Tensor& is valid, don't bother asserting
|
| 428 |
+
return v.toTensor();
|
| 429 |
+
}
|
| 430 |
+
};
|
| 431 |
+
|
| 432 |
+
template <bool AllowDeprecatedTypes>
|
| 433 |
+
struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final {
|
| 434 |
+
static List<at::Tensor> call(IValue& v) {
|
| 435 |
+
return v.toTensorList();
|
| 436 |
+
}
|
| 437 |
+
};
|
| 438 |
+
|
| 439 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 440 |
+
struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
|
| 441 |
+
// If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and
|
| 442 |
+
// pass that to the operator. std::vector<T> is implicitly convertible to
|
| 443 |
+
// ArrayRef<T>.
|
| 444 |
+
static std::vector<T> call(IValue& v) {
|
| 445 |
+
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
|
| 446 |
+
}
|
| 447 |
+
};
|
| 448 |
+
template <bool AllowDeprecatedTypes>
|
| 449 |
+
struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
|
| 450 |
+
static std::vector<c10::SymInt> call(IValue& v) {
|
| 451 |
+
if (v.isIntList()) {
|
| 452 |
+
std::vector<c10::SymInt> r;
|
| 453 |
+
auto src = v.toIntList();
|
| 454 |
+
std::transform(
|
| 455 |
+
src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
|
| 456 |
+
return c10::SymInt(i);
|
| 457 |
+
});
|
| 458 |
+
return r;
|
| 459 |
+
} else {
|
| 460 |
+
return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::
|
| 461 |
+
call(v);
|
| 462 |
+
}
|
| 463 |
+
}
|
| 464 |
+
};
|
| 465 |
+
template <bool AllowDeprecatedTypes>
|
| 466 |
+
struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes>
|
| 467 |
+
final {
|
| 468 |
+
static OptionalArray<c10::SymInt> call(IValue& v) {
|
| 469 |
+
if (v.isIntList()) {
|
| 470 |
+
std::vector<c10::SymInt> r;
|
| 471 |
+
auto src = v.toIntList();
|
| 472 |
+
std::transform(
|
| 473 |
+
src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
|
| 474 |
+
return c10::SymInt(i);
|
| 475 |
+
});
|
| 476 |
+
return OptionalArray<c10::SymInt>(std::move(r));
|
| 477 |
+
} else {
|
| 478 |
+
return std::move(v).to<OptionalArray<c10::SymInt>>();
|
| 479 |
+
}
|
| 480 |
+
}
|
| 481 |
+
};
|
| 482 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 483 |
+
struct ivalue_to_arg<std::optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
|
| 484 |
+
// If an argument is std::optional<ArrayRef<T>>, convert the IValue to an
|
| 485 |
+
// std::optional<std::vector<T>> and pass that to the operator.
|
| 486 |
+
// OptionalArray<T> is basically a std::optional<std::vector<T>> but
|
| 487 |
+
// implicitly convertible to std::optional<ArrayRef<T>>.
|
| 488 |
+
static OptionalArray<T> call(IValue& v) {
|
| 489 |
+
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
|
| 490 |
+
}
|
| 491 |
+
};
|
| 492 |
+
|
| 493 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 494 |
+
struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final {
|
| 495 |
+
// If an argument is OptionalArrayRef<T>, convert the IValue to an
|
| 496 |
+
// std::optional<std::vector<T>> and pass that to the operator.
|
| 497 |
+
// OptionalArray<T> is basically a std::optional<std::vector<T>> but
|
| 498 |
+
// implicitly convertible to OptionalArrayRef<T>
|
| 499 |
+
static OptionalArray<T> call(IValue& v) {
|
| 500 |
+
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
|
| 501 |
+
}
|
| 502 |
+
};
|
| 503 |
+
|
| 504 |
+
// return_to_ivalue
|
| 505 |
+
template <class T, bool AllowDeprecatedTypes, class Enable = void>
|
| 506 |
+
struct return_to_ivalue final {};
|
| 507 |
+
|
| 508 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 509 |
+
struct return_to_ivalue<
|
| 510 |
+
T,
|
| 511 |
+
AllowDeprecatedTypes,
|
| 512 |
+
std::enable_if_t<!std::is_same_v<at::Tensor&, T>>>
|
| 513 |
+
final {
|
| 514 |
+
static IValue call(T&& v) {
|
| 515 |
+
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
|
| 516 |
+
return c10::ivalue::from(std::move(v));
|
| 517 |
+
}
|
| 518 |
+
static IValue copy(const T& v) {
|
| 519 |
+
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
|
| 520 |
+
return IValue(v);
|
| 521 |
+
}
|
| 522 |
+
};
|
| 523 |
+
|
| 524 |
+
// Special case to allow kernels to return `Tensor&`.
|
| 525 |
+
// TODO Delete this once kernels don't do that anymore
|
| 526 |
+
template <bool AllowDeprecatedTypes>
|
| 527 |
+
struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final {
|
| 528 |
+
static IValue call(at::Tensor& v) {
|
| 529 |
+
return c10::ivalue::from(v);
|
| 530 |
+
}
|
| 531 |
+
static IValue copy(at::Tensor& v) {
|
| 532 |
+
return IValue(v);
|
| 533 |
+
}
|
| 534 |
+
};
|
| 535 |
+
|
| 536 |
+
// wrap_kernel_functor_unboxed_
|
| 537 |
+
|
| 538 |
+
template <class KernelFunctor, class OpSignature>
|
| 539 |
+
struct wrap_kernel_functor_unboxed_ final {};
|
| 540 |
+
|
| 541 |
+
// This specialization is for kernels with a first argument that is NOT of type
|
| 542 |
+
// DispatchKeySet This includes kernels with 0 arguments.
|
| 543 |
+
template <class KernelFunctor, class ReturnType, class... ParameterTypes>
|
| 544 |
+
struct wrap_kernel_functor_unboxed_<
|
| 545 |
+
KernelFunctor,
|
| 546 |
+
ReturnType(ParameterTypes...)>
|
| 547 |
+
final {
|
| 548 |
+
static_assert(
|
| 549 |
+
std::is_same_v<
|
| 550 |
+
ReturnType,
|
| 551 |
+
typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
|
| 552 |
+
"Return type mismatch");
|
| 553 |
+
static_assert(
|
| 554 |
+
std::is_same_v<
|
| 555 |
+
guts::typelist::typelist<ParameterTypes...>,
|
| 556 |
+
typename guts::infer_function_traits_t<
|
| 557 |
+
KernelFunctor>::parameter_types>,
|
| 558 |
+
"Parameter types mismatch");
|
| 559 |
+
|
| 560 |
+
// See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
|
| 561 |
+
// doesn't use &&
|
| 562 |
+
static ReturnType call(
|
| 563 |
+
OperatorKernel* functor,
|
| 564 |
+
DispatchKeySet,
|
| 565 |
+
ParameterTypes... args) {
|
| 566 |
+
KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
|
| 567 |
+
// Note [Plumbing Keys Through The Dispatcher 2]
|
| 568 |
+
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
| 569 |
+
// This functor explicitly takes in a dispatchKeySet and drops it on the
|
| 570 |
+
// floor- it does not forward it to the registered kernel.
|
| 571 |
+
//
|
| 572 |
+
// This is due to the calling convention within the dispatcher, which
|
| 573 |
+
// expects all registered kernels to have a first argument of type
|
| 574 |
+
// DispatchKeySet.
|
| 575 |
+
// This is not the case for pretty much all manually written kernels,
|
| 576 |
+
// however- this functor serves to separate the calling convention of the
|
| 577 |
+
// dispatcher from the calling convention of manually written kernels.
|
| 578 |
+
return (*functor_)(std::forward<ParameterTypes>(args)...);
|
| 579 |
+
}
|
| 580 |
+
};
|
| 581 |
+
|
| 582 |
+
// This specialization is for kernels with a first argument of type
|
| 583 |
+
// DispatchKeySet
|
| 584 |
+
template <class KernelFunctor, class ReturnType, class... ParameterTypes>
|
| 585 |
+
struct wrap_kernel_functor_unboxed_<
|
| 586 |
+
KernelFunctor,
|
| 587 |
+
ReturnType(DispatchKeySet, ParameterTypes...)>
|
| 588 |
+
final {
|
| 589 |
+
static_assert(
|
| 590 |
+
std::is_same_v<
|
| 591 |
+
ReturnType,
|
| 592 |
+
typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
|
| 593 |
+
"Return type mismatch");
|
| 594 |
+
static_assert(
|
| 595 |
+
std::is_same_v<
|
| 596 |
+
guts::typelist::typelist<DispatchKeySet, ParameterTypes...>,
|
| 597 |
+
typename guts::infer_function_traits_t<
|
| 598 |
+
KernelFunctor>::parameter_types>,
|
| 599 |
+
"Parameter types mismatch");
|
| 600 |
+
|
| 601 |
+
// See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
|
| 602 |
+
// doesn't use &&
|
| 603 |
+
static ReturnType call(
|
| 604 |
+
OperatorKernel* functor,
|
| 605 |
+
DispatchKeySet dispatchKeySet,
|
| 606 |
+
ParameterTypes... args) {
|
| 607 |
+
KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
|
| 608 |
+
// We're explicitly taking in a dispatchKeySet and forwarding it to the
|
| 609 |
+
// registered kernel. See Note [Plumbing Keys Through The Dispatcher 2] for
|
| 610 |
+
// details.
|
| 611 |
+
return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
|
| 612 |
+
}
|
| 613 |
+
};
|
| 614 |
+
|
| 615 |
+
template <class KernelFunctor>
|
| 616 |
+
using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<
|
| 617 |
+
KernelFunctor,
|
| 618 |
+
typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
|
| 619 |
+
|
| 620 |
+
// call_functor_with_args_from_stack
|
| 621 |
+
|
| 622 |
+
template <
|
| 623 |
+
class Functor,
|
| 624 |
+
bool AllowDeprecatedTypes,
|
| 625 |
+
size_t... ivalue_arg_indices,
|
| 626 |
+
typename... ArgTypes>
|
| 627 |
+
std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
|
| 628 |
+
call_functor_with_args_from_stack_(
|
| 629 |
+
OperatorKernel* functor,
|
| 630 |
+
DispatchKeySet dispatchKeySet,
|
| 631 |
+
Stack* stack,
|
| 632 |
+
std::index_sequence<ivalue_arg_indices...>,
|
| 633 |
+
guts::typelist::typelist<ArgTypes...>*) {
|
| 634 |
+
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
|
| 635 |
+
// be unused and we have to silence the compiler warning.
|
| 636 |
+
|
| 637 |
+
// We're explicitly filtering out DispatchKeySet from the argument list.
|
| 638 |
+
// Some kernels take a DispatchKeySet as their first argument in order to
|
| 639 |
+
// plumb keys through the dispatcher. We don't want to expose the
|
| 640 |
+
// DispatchKeySet type to jit, so we don't include this argument on the stack.
|
| 641 |
+
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
| 642 |
+
return wrap_kernel_functor_unboxed<Functor>::call(
|
| 643 |
+
functor,
|
| 644 |
+
dispatchKeySet,
|
| 645 |
+
ivalue_to_arg<
|
| 646 |
+
typename decay_if_not_tensor<ArgTypes>::type,
|
| 647 |
+
AllowDeprecatedTypes>::
|
| 648 |
+
call(torch::jit::peek(
|
| 649 |
+
*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))...);
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
template <class Functor, bool AllowDeprecatedTypes>
|
| 653 |
+
std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
|
| 654 |
+
call_functor_with_args_from_stack(
|
| 655 |
+
OperatorKernel* functor,
|
| 656 |
+
DispatchKeySet dispatchKeySet,
|
| 657 |
+
Stack* stack) {
|
| 658 |
+
// We're explicitly filtering out DispatchKeySet from the argument list.
|
| 659 |
+
// Some kernels take a DispatchKeySet as their first argument in order to
|
| 660 |
+
// plumb keys through the dispatcher. We don't want to expose the
|
| 661 |
+
// DispatchKeySet type to jit, so we don't include this argument on the stack.
|
| 662 |
+
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
| 663 |
+
using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
|
| 664 |
+
Functor>::parameter_types;
|
| 665 |
+
constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value;
|
| 666 |
+
return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(
|
| 667 |
+
functor,
|
| 668 |
+
dispatchKeySet,
|
| 669 |
+
stack,
|
| 670 |
+
std::make_index_sequence<num_ivalue_args>(),
|
| 671 |
+
static_cast<ArgTypes*>(nullptr));
|
| 672 |
+
}
|
| 673 |
+
|
| 674 |
+
// push_outputs
|
| 675 |
+
|
| 676 |
+
template <class OutputType, bool AllowDeprecatedTypes>
|
| 677 |
+
struct push_outputs final {
|
| 678 |
+
// Contrary to [Note: Argument forwarding in the dispatcher], we use
|
| 679 |
+
// OutputType&& here to avoid one extra call to the move constructor in this
|
| 680 |
+
// case. This is still not a universal reference though because OutputType is
|
| 681 |
+
// an explicitly specified class template parameter.
|
| 682 |
+
static void call(OutputType&& output, Stack* stack) {
|
| 683 |
+
torch::jit::push(
|
| 684 |
+
*stack,
|
| 685 |
+
return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(
|
| 686 |
+
std::forward<OutputType>(output)));
|
| 687 |
+
}
|
| 688 |
+
static void copy(const OutputType& output, Stack* stack) {
|
| 689 |
+
torch::jit::push(
|
| 690 |
+
*stack,
|
| 691 |
+
return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
|
| 692 |
+
}
|
| 693 |
+
};
|
| 694 |
+
template <class... OutputTypes, bool AllowDeprecatedTypes>
|
| 695 |
+
struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
|
| 696 |
+
static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
|
| 697 |
+
call_(
|
| 698 |
+
std::move(output),
|
| 699 |
+
stack,
|
| 700 |
+
std::make_index_sequence<sizeof...(OutputTypes)>());
|
| 701 |
+
}
|
| 702 |
+
static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
|
| 703 |
+
copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
private:
|
| 707 |
+
template <size_t... indices>
|
| 708 |
+
static void call_(
|
| 709 |
+
std::tuple<OutputTypes...>&& output,
|
| 710 |
+
Stack* stack,
|
| 711 |
+
std::index_sequence<indices...>) {
|
| 712 |
+
torch::jit::push(
|
| 713 |
+
*stack,
|
| 714 |
+
return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(
|
| 715 |
+
std::forward<OutputTypes>(std::get<indices>(output)))...);
|
| 716 |
+
}
|
| 717 |
+
template <size_t... indices>
|
| 718 |
+
static void copy_(
|
| 719 |
+
const std::tuple<OutputTypes...>& output,
|
| 720 |
+
Stack* stack,
|
| 721 |
+
std::index_sequence<indices...>) {
|
| 722 |
+
torch::jit::push(
|
| 723 |
+
*stack,
|
| 724 |
+
return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(
|
| 725 |
+
std::get<indices>(output))...);
|
| 726 |
+
}
|
| 727 |
+
};
|
| 728 |
+
template <bool AllowDeprecatedTypes>
|
| 729 |
+
struct push_outputs<void, AllowDeprecatedTypes> final {
|
| 730 |
+
static void call(int /*dummy*/, Stack* /*stack*/) {}
|
| 731 |
+
static void copy(int /*dummy*/, Stack* /*stack*/) {}
|
| 732 |
+
};
|
| 733 |
+
|
| 734 |
+
// make_boxed_from_unboxed_functor
|
| 735 |
+
|
| 736 |
+
template <class KernelFunctor, bool AllowDeprecatedTypes>
|
| 737 |
+
struct make_boxed_from_unboxed_functor final {
|
| 738 |
+
static_assert(
|
| 739 |
+
std::is_base_of_v<OperatorKernel, KernelFunctor>,
|
| 740 |
+
"Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
| 741 |
+
|
| 742 |
+
static void call(
|
| 743 |
+
OperatorKernel* functor,
|
| 744 |
+
const OperatorHandle&,
|
| 745 |
+
DispatchKeySet dispatchKeySet,
|
| 746 |
+
Stack* stack) {
|
| 747 |
+
using ReturnType =
|
| 748 |
+
typename guts::infer_function_traits_t<KernelFunctor>::return_type;
|
| 749 |
+
// We're explicitly filtering out DispatchKeySet from the argument list.
|
| 750 |
+
// Some kernels take a DispatchKeySet as their first argument in order to
|
| 751 |
+
// plumb keys through the dispatcher. We don't want to expose the
|
| 752 |
+
// DispatchKeySet type to jit, so we don't include this argument on the
|
| 753 |
+
// stack. See Note [Plumbing Keys Through The Dispatcher] for the
|
| 754 |
+
// background.
|
| 755 |
+
using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
|
| 756 |
+
KernelFunctor>::parameter_types;
|
| 757 |
+
constexpr bool has_outputs = !std::is_same_v<void, ReturnType>;
|
| 758 |
+
constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value;
|
| 759 |
+
if constexpr (has_outputs) {
|
| 760 |
+
// Decay ReturnType to ReturnType_ so that if a reference gets returned,
|
| 761 |
+
// we actually store it by value and don't get a dangling reference. This
|
| 762 |
+
// is only required because some kernels still return `Tensor&`. [Note:
|
| 763 |
+
// VC++ and 'std': ambiguous symbol]
|
| 764 |
+
using ReturnType_ = ::std::decay_t<ReturnType>;
|
| 765 |
+
ReturnType_ output = call_functor_with_args_from_stack<
|
| 766 |
+
KernelFunctor,
|
| 767 |
+
AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
|
| 768 |
+
torch::jit::drop(*stack, num_inputs);
|
| 769 |
+
// See note [ VC++ and 'std': ambiguous symbol]
|
| 770 |
+
push_outputs<ReturnType_, AllowDeprecatedTypes>::call(
|
| 771 |
+
::std::move(output), stack);
|
| 772 |
+
} else {
|
| 773 |
+
call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(
|
| 774 |
+
functor, dispatchKeySet, stack);
|
| 775 |
+
torch::jit::drop(*stack, num_inputs);
|
| 776 |
+
}
|
| 777 |
+
}
|
| 778 |
+
};
|
| 779 |
+
} // namespace impl
|
| 780 |
+
|
| 781 |
+
} // namespace c10
|
| 782 |
+
|
| 783 |
+
namespace torch {
|
| 784 |
+
using OperatorKernel = c10::OperatorKernel;
|
| 785 |
+
}
|
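The assert_is_valid_input_type / assert_is_valid_output_type machinery above is what emits the "use double / int64_t / std::string_view instead" errors at kernel-registration time. A minimal sketch of a custom-op registration that satisfies those checks is shown below; the `myops` namespace, the operator name, and the kernel body are hypothetical and not part of the vendored header.

#include <torch/library.h>

// Hypothetical kernel: int64_t/double in the C++ signature, int/float in the schema string.
at::Tensor scale_add(const at::Tensor& x, double alpha, int64_t repeat) {
  return x * alpha + static_cast<double>(repeat);
}

TORCH_LIBRARY(myops, m) {
  m.def("scale_add(Tensor x, float alpha, int repeat) -> Tensor", &scale_add);
}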
phivenv/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h
ADDED
@@ -0,0 +1,140 @@
#pragma once

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/irange.h>

template <class... Inputs>
inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
  return {std::forward<Inputs>(inputs)...};
}

inline at::Tensor dummyTensor(
    c10::DispatchKeySet ks,
    bool requires_grad = false) {
  auto* allocator = c10::GetCPUAllocator();
  int64_t nelements = 1;
  auto dtype = caffe2::TypeMeta::Make<float>();
  int64_t size_bytes = nelements * dtype.itemsize();
  auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size_bytes,
      allocator->allocate(size_bytes),
      allocator,
      /*resizable=*/true);
  at::Tensor t =
      at::detail::make_tensor<c10::TensorImpl>(storage_impl, ks, dtype);
  // TODO: We add this to simulate the ideal case where we only have Autograd
  // backend keys
  // on Tensor when it requires grad. But currently Autograd keys are
  // added in TensorImpl constructor by default.
  if (!requires_grad) {
    t.unsafeGetTensorImpl()->remove_autograd_key();
  }
  return t;
}

inline at::Tensor dummyTensor(
    c10::DispatchKey dispatch_key,
    bool requires_grad = false) {
  return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad);
}

template <class... Args>
inline std::vector<c10::IValue> callOp(
    const c10::OperatorHandle& op,
    Args... args) {
  auto stack = makeStack(std::forward<Args>(args)...);
  op.callBoxed(&stack);
  return stack;
}

template <class Result, class... Args>
inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
  return op.typed<Result(Args...)>().call(std::forward<Args>(args)...);
}

template <class Result, class... Args>
inline Result callOpUnboxedWithDispatchKey(
    const c10::OperatorHandle& op,
    c10::DispatchKey dispatchKey,
    Args... args) {
  return op.typed<Result(Args...)>().callWithDispatchKey(
      dispatchKey, std::forward<Args>(args)...);
}

template <class Result, class... Args>
inline Result callOpUnboxedWithPrecomputedDispatchKeySet(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet ks,
    Args... args) {
  return op.typed<Result(Args...)>().redispatch(
      ks, std::forward<Args>(args)...);
}

inline void expectDoesntFindKernel(
    const char* op_name,
    c10::DispatchKey dispatch_key) {
  auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
  EXPECT_ANY_THROW(callOp(*op, dummyTensor(dispatch_key), 5););
}

inline void expectDoesntFindOperator(const char* op_name) {
  auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
  EXPECT_FALSE(op.has_value());
}

template <class Exception, class Functor>
inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
  try {
    std::forward<Functor>(functor)();
  } catch (const Exception& e) {
    EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
    return;
  }
  ADD_FAILURE() << "Expected to throw exception containing \""
                << expectMessageContains << "\" but didn't throw";
}

template <class T, size_t N>
void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
  EXPECT_EQ(expected.size(), actual.size());
  for (const auto i : c10::irange(expected.size())) {
    EXPECT_EQ(expected[i], actual[i]);
  }
}

template <class T>
void expectListEquals(c10::ArrayRef<T> expected, c10::ArrayRef<T> actual) {
  EXPECT_EQ(expected.size(), actual.size());
  for (const auto i : c10::irange(expected.size())) {
    EXPECT_EQ(expected[i], actual[i]);
  }
}

template <class T>
void expectListEquals(c10::ArrayRef<T> expected, c10::List<T> actual) {
  EXPECT_EQ(expected.size(), actual.size());
  for (const auto i : c10::irange(expected.size())) {
    EXPECT_EQ(expected[i], actual.get(i));
  }
}

template <class T>
void expectListEquals(c10::ArrayRef<T> expected, std::vector<T> actual) {
  EXPECT_EQ(expected.size(), actual.size());
  for (const auto i : c10::irange(expected.size())) {
    EXPECT_EQ(expected[i], actual[i]);
  }
}

// NB: This is not really sound, but all of the type sets constructed here
// are singletons so it's fine
static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
  return legacyExtractDispatchKey(t.key_set());
}
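A minimal usage sketch of how these helpers are typically combined when exercising the boxed calling path in a gtest; the operator name "_test::my_op" is hypothetical and assumed to have been registered elsewhere.

TEST(BoxedCallTest, CallsRegisteredKernel) {
  auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
  ASSERT_TRUE(op.has_value());
  // Push the inputs, run the boxed kernel, and inspect the outputs it left on the stack.
  auto outputs = callOp(*op, dummyTensor(c10::DispatchKey::CPU), int64_t(3));
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(c10::DispatchKey::CPU, extractDispatchKey(outputs[0].toTensor()));
}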
phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h
ADDED
@@ -0,0 +1,67 @@
#pragma once

#include <c10/core/DispatchKeySet.h>
#include <c10/macros/Macros.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/Type.h>
#include <typeindex>

namespace c10::impl {

// A CppSignature object holds RTTI information about a C++ function signature
// at runtime and can compare them or get a debug-printable name.
class TORCH_API CppSignature final {
 public:
  CppSignature(const CppSignature&) = default;
  CppSignature(CppSignature&&) noexcept = default;
  CppSignature& operator=(const CppSignature&) = default;
  CppSignature& operator=(CppSignature&&) noexcept = default;

  template <class FuncType>
  static CppSignature make() {
    // Normalize functors, lambdas, function pointers, etc. into the plain
    // function type The first argument of the schema might be of type
    // DispatchKeySet, in which case we remove it. We do this to guarantee that
    // all CppSignature's for an operator will match, even if they're registered
    // with different calling conventions.
    // See Note [Plumbing Keys Through The Dispatcher]
    using decayed_function_type =
        typename c10::remove_DispatchKeySet_arg_from_func<
            std::decay_t<FuncType>>::func_type;

    return CppSignature(std::type_index(typeid(decayed_function_type)));
  }

  std::string name() const {
    return c10::demangle(signature_.name());
  }

  friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) {
    if (lhs.signature_ == rhs.signature_) {
      return true;
    }
    // Without RTLD_GLOBAL, the type_index comparison could yield false because
    // they point to different instances of the RTTI data, but the types would
    // still be the same. Let's check for that case too.
    // Note that there still is a case where this might not work, i.e. when
    // linking libraries of different compilers together, they might have
    // different ways to serialize a type name. That, together with a missing
    // RTLD_GLOBAL, would still fail this.
    if (0 == strcmp(lhs.signature_.name(), rhs.signature_.name())) {
      return true;
    }

    return false;
  }

 private:
  explicit CppSignature(std::type_index signature)
      : signature_(std::move(signature)) {}
  std::type_index signature_;
};

inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) {
  return !(lhs == rhs);
}

} // namespace c10::impl
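As the comment in make() notes, a leading DispatchKeySet argument is stripped before the std::type_index is taken, so signatures registered with and without the plumbed key set compare equal. A small illustrative sketch (not from the header):

using PlainFn = at::Tensor(const at::Tensor&, int64_t);
using PlumbedFn = at::Tensor(c10::DispatchKeySet, const at::Tensor&, int64_t);

void check_signatures() {
  auto a = c10::impl::CppSignature::make<PlainFn>();
  auto b = c10::impl::CppSignature::make<PlumbedFn>();
  TORCH_INTERNAL_ASSERT(a == b); // the DispatchKeySet argument is normalized away
  // a.name() returns a demangled, human-readable form for error messages.
}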
phivenv/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h
ADDED
@@ -0,0 +1,279 @@
#pragma once

#include <ATen/core/Variadic.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Bitset.h>
#include <c10/util/irange.h>
#include <cstdint>

namespace c10 {

namespace impl {

// Take a DispatchKeySet for a Tensor and determine what the actual dispatch
// DispatchKey should be, taking into account TLS, and skipping backends which
// fall through.
//
// Unlike Tensor::key_set(), the value of this on a tensor can change depending
// on TLS.
//
// NB: If there is no valid dispatch key, this will return Undefined
inline DispatchKeySet computeDispatchKeySet(
    DispatchKeySet ks,
    // The key mask lets us eliminate (by zero entries) keys which should not
    // be considered for dispatch. There are two cases when we use this:
    //
    // - If an operator's dispatch table contains a fallthrough entry, we
    //   should bypass it entirely when finding the key
    // - If a user invokes with redispatch, the mask lets us
    //   zero out the key the user asked us to stop.
    //
    // These excluded backends are NOT tracked in the TLS, but must be applied
    // AFTER TLS (since the backend may have been introduced for consideration
    // by the included TLS), which is why you have to pass them in to this
    // function (as opposed to just applying it to the input 'ks').
    DispatchKeySet key_mask) {
  c10::impl::LocalDispatchKeySet local =
      c10::impl::tls_local_dispatch_key_set();
  // TODO: It's a bit irritating that we have to do logical ORs here, it would
  // be nice to only do one. Can always_included be folded into the TLS? Well,
  // it's a bit troublesome, because fastpath TLS access requires the type of
  // the TLS in question to be zero-initialized, so you don't actually win
  // anything in that case.
  return (((ks | local.included_) - local.excluded_) & key_mask);
}

} // namespace impl

namespace detail {
// A small gadget to extract the DispatchKeySet from types which are known
// to have it. Used to extract dispatch keys from unboxed calls.
struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
  DispatchKeySet ts;
  void operator()(const at::Tensor& x) {
    ts = ts | x.key_set();
  }
  void operator()(const std::optional<at::Tensor>& x) {
    if (x.has_value()) {
      ts = ts | x->key_set();
    }
  }
  void operator()(at::ArrayRef<at::Tensor> xs) {
    for (const auto& x : xs) {
      ts = ts | x.key_set();
    }
  }
  // Tensor?[] translates to this case.
  void operator()(const c10::List<std::optional<at::Tensor>>& xs) {
    for (std::optional<at::Tensor> x : xs) {
      if (x.has_value()) {
        ts = ts | x.value().key_set();
      }
    }
  }
  // Structured Tensor[] translates to this case
  void operator()(const at::ITensorListRef& xs) {
    for (const auto& x : xs) {
      ts = ts | x.key_set();
    }
  }
  [[noreturn]] void operator()(at::ArrayRef<std::optional<at::Tensor>>) {
    // Just checking that the handling of Tensor?[] didn't change.
    TORCH_INTERNAL_ASSERT(false);
  }
  void operator()(const at::Generator& gen) {
    if (gen.defined()) {
      ts = ts | gen.key_set();
    }
  }
  void operator()(const std::optional<at::Generator>& gen) {
    if (gen.has_value() && gen->defined()) {
      ts = ts | gen->key_set();
    }
  }
  template <typename T>
  void operator()(const T&) {
    // do nothing
  }
};

// NB: take by const reference (Don't do universal forwarding here! You
// don't want to move into this function!)
template <typename... Args>
DispatchKeySet multi_dispatch_key_set(const Args&... args) {
  return MultiDispatchKeySet().apply(args...).ts;
}
} // namespace detail

/**
 * An instance of DispatchKeyExtractor knows how to get a dispatch key given
 * a list of arguments for an operator call.
 *
 * The instance is specific for a certain operator as:
 *  - In boxed dispatch, different operators have different ways to extract
 *    the dispatch key (e.g. different numbers of arguments), and we precompute
 *    the stack locations we should look at; and
 *  - In all dispatch, some backends should be excluded from dispatch because
 *    they have been registered as fallthrough. The set of excluded backends
 *    varies from operator, as some operators may have overridden the
 *    fallthrough with custom behavior.
 *
 *   Note - this should maintain identical impl to the py dispatcher key
 *   extraction logic at pytorch/torch/dispatcher.py
 */
struct TORCH_API DispatchKeyExtractor final {
 public:
  static DispatchKeyExtractor make(const FunctionSchema& schema) {
    return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
  }

  static DispatchKeyExtractor makeUninitialized() {
    return DispatchKeyExtractor(c10::utils::bitset());
  }

  void registerSchema(const FunctionSchema& schema) {
    TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
    dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
  }
  void deregisterSchema() {
    dispatch_arg_indices_reverse_ = c10::utils::bitset();
  }

  DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
    DispatchKeySet ks;
    dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t
                                                           reverse_arg_index) {
      const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
      if (C10_LIKELY(ivalue.isTensor())) {
        // NB: Take care not to introduce a refcount bump (there's
        // no safe toTensorRef method, alas)
        ks = ks | ivalue.unsafeToTensorImpl()->key_set();
      } else if (C10_UNLIKELY(ivalue.isTensorList())) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          auto* tensor = nv.unsafeToTensorImpl();
          ks = ks | tensor->key_set();
        }
      }
      // Tensor?[] translates to a c10::List<IValue> so we need to peek inside
      else if (C10_UNLIKELY(ivalue.isList())) {
        for (const auto& elt : ivalue.toListRef()) {
          if (elt.isTensor()) {
            ks = ks | elt.toTensor().key_set();
          }
        }
      }
    });
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      c10::impl::LocalDispatchKeySet tls =
          c10::impl::tls_local_dispatch_key_set();
      auto backend_idx =
          ((ks | tls.included_) - tls.excluded_).getBackendIndex();
      return impl::computeDispatchKeySet(
          ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }

  template <class... Args>
  DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
    auto ks = detail::multi_dispatch_key_set(args...);
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      c10::impl::LocalDispatchKeySet tls =
          c10::impl::tls_local_dispatch_key_set();
      auto backend_idx =
          ((ks | tls.included_) - tls.excluded_).getBackendIndex();
      return impl::computeDispatchKeySet(
          ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }

  void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);

  std::string dumpState() const;
  void checkInvariants(const FunctionSchema& schema) const;

 private:
  static bool isDispatchType(const Type& type) {
    // Checking isSubtypeOf on a DynamicType heap-allocates a
    // DynamicType version of the argument if it's not a DynamicType
    // already, and this has measurable overhead during startup.
#ifdef C10_MOBILE
    struct CachedTypes {
      DynamicTypePtr listOfTensors;
      DynamicTypePtr listOfOptionalTensors;
      DynamicTypePtr optionalOfTensor;
    };
    static const CachedTypes ct = {
        DynamicType::create(*ListType::ofTensors()),
        DynamicType::create(*ListType::ofOptionalTensors()),
        DynamicType::create(*OptionalType::ofTensor())};
    return type.isSubtypeOf(c10::TypeFactory::get<TensorType>()) ||
        type.isSubtypeOf(ct.listOfTensors) ||
        type.isSubtypeOf(ct.listOfOptionalTensors) ||
        type.isSubtypeOf(ct.optionalOfTensor);
#else // C10_MOBILE
    return type.isSubtypeOf(*TensorType::get()) ||
        type.isSubtypeOf(*ListType::ofTensors()) ||
        type.isSubtypeOf(*ListType::ofOptionalTensors()) ||
        type.isSubtypeOf(*OptionalType::ofTensor());
#endif // C10_MOBILE
  }
  static c10::utils::bitset makeBitsetForDispatchArgs(
      const FunctionSchema& schema) {
    TORCH_CHECK(
        schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
        "The function schema has ",
        schema.arguments().size(),
        " arguments but this PyTorch build only supports ",
        c10::utils::bitset::NUM_BITS());
    c10::utils::bitset dispatch_arg_indices_reverse;
    for (const auto index : c10::irange(schema.arguments().size())) {
      if (isDispatchType(*schema.arguments()[index].type())) {
        dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
      }
    }
    return dispatch_arg_indices_reverse;
  }

  explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
      : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse),
        nonFallthroughKeys_(DispatchKeySet::FULL) {
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
    }
  }

  // this is a bitset that has ones for each argument index which has to be
  // considered for dispatch. This avoids having to iterate over the stack
  // to find all the tensors. The bits are stored in reverse order, i.e.
  // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
  // the top of the stack (i.e. the i-th last argument of the function)
  // is relevant for dispatch.
  // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just
  // means you must do the fallthrough
  c10::utils::bitset dispatch_arg_indices_reverse_;

  // Set of functionality keys for which the operator does NOT have fallthrough
  // kernel.
  DispatchKeySet nonFallthroughKeys_;
  // Set of functionality keys for which the operator does NOT have fallthrough
  // kernel, defined PER BACKEND. This is only needed if we know that the
  // operator has a different set of fallthroughs defined for some backends.
  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
  // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast
  // path), or if we need to fall back to the slower path and check
  // nonFallthroughKeysPerBackend_
  bool requiresBitsetPerBackend_{false};
};

} // namespace c10
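In both the boxed and unboxed paths above, key resolution reduces to ((ks | tls.included_) - tls.excluded_) & key_mask applied to the union of the argument key sets. A rough sketch of that reduction, assuming a single tensor argument and a per-operator non-fallthrough mask (the helper name is hypothetical):

c10::DispatchKeySet resolveForTensor(const at::Tensor& t, c10::DispatchKeySet mask) {
  auto ks = c10::detail::multi_dispatch_key_set(t);   // union of argument key sets
  return c10::impl::computeDispatchKeySet(ks, mask);  // fold in TLS, drop fallthrough keys
  // The dispatcher then dispatches to the highest-priority key left in this set.
}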