Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h +218 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h +111 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction.h +346 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h +395 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h +32 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h +43 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h +46 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/boxing.h +415 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +790 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h +145 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/CppSignature.h +72 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h +285 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h +955 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h +22 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h +342 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h +35 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h +41 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/adaption.h +86 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/infer_schema.h +162 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h +186 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_registration.h +599 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/FlushDenormal.h +19 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/Utils.h +38 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vml.h +175 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/BLASConstants.h +16 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h +76 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h +156 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh +41 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh +129 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h +42 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h +16 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh +141 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +48 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh +121 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh +39 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h +705 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h +692 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h +282 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h +55 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/Tunable.h +270 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h +334 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h +434 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h +43 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h +486 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h +86 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h +181 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h +131 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h +129 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h +27 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h +358 -0
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/boxing/OperatorKernel.h>
|
| 5 |
+
#include <c10/core/DispatchKeySet.h>
|
| 6 |
+
#include <c10/util/intrusive_ptr.h>
|
| 7 |
+
|
| 8 |
+
namespace c10 {
|
| 9 |
+
|
| 10 |
+
struct IValue;
|
| 11 |
+
using Stack = std::vector<IValue>;
|
| 12 |
+
|
| 13 |
+
class OperatorHandle;
|
| 14 |
+
class KernelFunction;
|
| 15 |
+
|
| 16 |
+
// This kernel implements the behavior of falling through to the next available
|
| 17 |
+
// registered dispatch key. The implementation of this function is FAST; there is
|
| 18 |
+
// no overhead to fall through to the next key. See cpp file for some more
|
| 19 |
+
// implementation notes; notably, this does NOT actually go through the
|
| 20 |
+
// boxing/unboxing codepath.
|
| 21 |
+
TORCH_API void fallthrough_kernel(
|
| 22 |
+
OperatorKernel* /*unused*/,
|
| 23 |
+
const OperatorHandle& /*unused*/,
|
| 24 |
+
DispatchKeySet /*unused*/,
|
| 25 |
+
Stack* /*unused*/);
|
| 26 |
+
|
| 27 |
+
// Note [Ambiguity in AutogradOther kernel]
|
| 28 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 29 |
+
// This error-reporting kernel is registered to the AutogradOther entry in the
|
| 30 |
+
// dispatch table when there is both a CompositeImplicitAutograd kernel and a
|
| 31 |
+
// backend kernel for ANY backend that maps to AutogradOther. To see why
|
| 32 |
+
// this is necessary in the AutogradOther case, it's helpful to first see
|
| 33 |
+
// why everything works out fine for a backend that has a reserved Autograd
|
| 34 |
+
// entry (see rule 2.2 in [Note] DispatchTable computation):
|
| 35 |
+
//
|
| 36 |
+
// CPU AutogradCPU
|
| 37 |
+
// reg? registers with...
|
| 38 |
+
// -------------------------------------------------
|
| 39 |
+
// y Autograd registration takes precedence
|
| 40 |
+
// over CompositeImplicitAutograd.
|
| 41 |
+
// This is good, because the CPU specific backend
|
| 42 |
+
// implementation is more specialized and typically better;
|
| 43 |
+
// if we used the composite, we would bypass it.
|
| 44 |
+
// (NB: the Autograd key is guaranteed to exist because
|
| 45 |
+
// the autograd codegen requires it!)
|
| 46 |
+
//
|
| 47 |
+
// n CompositeImplicitAutograd takes precedence.
|
| 48 |
+
// This is also good, because the Autograd
|
| 49 |
+
// registration (if it exists) would try to redispatch
|
| 50 |
+
// to the (non-existent) CPU implementation; by
|
| 51 |
+
// using the composite, we ensure the operator
|
| 52 |
+
// actually works.
|
| 53 |
+
//
|
| 54 |
+
// As you can see, when we have a specific Autograd key (AutogradCPU), we can
|
| 55 |
+
// decide whether or not to use the CompositeImplicitAutograd kernel or the
|
| 56 |
+
// Autograd kernel based on whether or not the backend kernel exists.
|
| 57 |
+
//
|
| 58 |
+
// However, for AutogradOther (which is the catchall autograd kernel for
|
| 59 |
+
// everything that doesn't have a specific Autograd key), we can't do this
|
| 60 |
+
// trick because there isn't any unique backend to peek at to disambiguate;
|
| 61 |
+
// backends that have implementations would prefer Autograd,
|
| 62 |
+
// but unimplemented backends would prefer CompositeImplicitAutograd. Rather
|
| 63 |
+
// than arbitrarily pick one or the other, we just register a kernel that raises
|
| 64 |
+
// an error and let the user decide how to proceed.
|
| 65 |
+
TORCH_API void ambiguous_autogradother_kernel(
|
| 66 |
+
OperatorKernel* /*unused*/,
|
| 67 |
+
const OperatorHandle& /*op*/,
|
| 68 |
+
DispatchKeySet /*unused*/,
|
| 69 |
+
Stack* /*unused*/);
|
| 70 |
+
|
| 71 |
+
// Note [named_not_supported_kernel]
|
| 72 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 73 |
+
// This kernel implements reporting an error message saying that named tensors are
|
| 74 |
+
// not supported. This kernel doesn't rely on the Stack, and so it is special
|
| 75 |
+
// cased in the dispatcher to be triggered before we attempt boxing (so we can
|
| 76 |
+
// give a good error message in cases when boxing is not supported). When
|
| 77 |
+
// boxing is universally supported this can be removed.
|
| 78 |
+
[[noreturn]] TORCH_API void named_not_supported_kernel(
|
| 79 |
+
OperatorKernel* /*unused*/,
|
| 80 |
+
const OperatorHandle& /*op*/,
|
| 81 |
+
DispatchKeySet /*unused*/,
|
| 82 |
+
Stack* /*unused*/);
|
| 83 |
+
|
| 84 |
+
/**
|
| 85 |
+
* BoxedKernel is similar to a std::function storing a boxed kernel.
|
| 86 |
+
*/
|
| 87 |
+
class TORCH_API BoxedKernel final {
|
| 88 |
+
public:
|
| 89 |
+
// This is how boxed kernels are actually stored
|
| 90 |
+
//
|
| 91 |
+
// Note [Plumbing Keys Through The Dispatcher]
|
| 92 |
+
// Benchmarks have shown that it is expensive for the dispatcher to read from
|
| 93 |
+
// thread-local storage (TLS) upon every dispatch call in order to compute
|
| 94 |
+
// which kernel to dispatch to.
|
| 95 |
+
//
|
| 96 |
+
// To mitigate this, we've updated the calling convention inside the
|
| 97 |
+
// dispatcher to expect every kernel that it stores to have a first argument
|
| 98 |
+
// of type DispatchKeySet.
|
| 99 |
+
//
|
| 100 |
+
// What are the invariants of the DispatchKeySet when it gets passed to a
|
| 101 |
+
// kernel?
|
| 102 |
+
// - All keys to the left of the current dispatch key have been masked out.
|
| 103 |
+
// (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the
|
| 104 |
+
// highest bit to be DispatchKey::Tracer)
|
| 105 |
+
// - All other keys that the dispatcher normally would have computed through TLS +
|
| 106 |
+
// global state + op arguments
|
| 107 |
+
// are still in the set.
|
| 108 |
+
//
|
| 109 |
+
// Kernels can then opt into using this keyset to save the dispatcher from
|
| 110 |
+
// doing repeated work during redispatches: recalculating the highest-priority
|
| 111 |
+
// dispatch key, which involves reading from TLS. Instead, the kernels that
|
| 112 |
+
// opt in will calculate an updated DispatchKeySet directly from the old one,
|
| 113 |
+
// and pass the updated set directly into the dispatcher upon redispatching.
|
| 114 |
+
//
|
| 115 |
+
// This is an opt-in mechanism: Kernels can automatically opt in by setting
|
| 116 |
+
// the first argument in their signature to be of type DispatchKeySet. See the
|
| 117 |
+
// kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for
|
| 118 |
+
// examples.
|
| 119 |
+
//
|
| 120 |
+
// The mechanism for optionally passing that DispatchKeySet into the kernel
|
| 121 |
+
// lives in make_boxed_from_unboxed_functor.h. See Note [Plumbing Keys Through
|
| 122 |
+
// The Dispatcher 2] for details.
|
| 123 |
+
using InternalBoxedKernelFunction =
|
| 124 |
+
void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
|
| 125 |
+
// This is the public API for how boxed kernels are defined
|
| 126 |
+
using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
|
| 127 |
+
using BoxedKernelFunction_withDispatchKeys =
|
| 128 |
+
void(const OperatorHandle&, DispatchKeySet, Stack*);
|
| 129 |
+
|
| 130 |
+
BoxedKernel();
|
| 131 |
+
|
| 132 |
+
// Fast path for dispatch to allow not touching the boxed kernel in
|
| 133 |
+
// the common case where unboxed is available.
|
| 134 |
+
bool isValid() const;
|
| 135 |
+
bool isFallthrough() const;
|
| 136 |
+
|
| 137 |
+
/**
|
| 138 |
+
* Call the function with boxed arguments.
|
| 139 |
+
*/
|
| 140 |
+
void callBoxed(
|
| 141 |
+
const OperatorHandle& opHandle,
|
| 142 |
+
DispatchKeySet dispatchKeySet,
|
| 143 |
+
Stack* stack) const;
|
| 144 |
+
|
| 145 |
+
/**
|
| 146 |
+
* Create a BoxedKernel from a boxed function.
|
| 147 |
+
*
|
| 148 |
+
* Example:
|
| 149 |
+
*
|
| 150 |
+
* > void boxed_func(const OperatorHandle&, Stack* stack) {...}
|
| 151 |
+
* > BoxedKernel func = BoxedKernel::makeFromFunction<&boxed_func>();
|
| 152 |
+
*/
|
| 153 |
+
template <BoxedKernelFunction* func>
|
| 154 |
+
static BoxedKernel makeFromFunction();
|
| 155 |
+
|
| 156 |
+
/**
|
| 157 |
+
* TODO: This will only be useful if we write a backend fallback that plumbs
|
| 158 |
+
* dispatch keys (currently there are none) See Note [Plumbing Keys Through
|
| 159 |
+
* The Dispatcher] for details.
|
| 160 |
+
*/
|
| 161 |
+
template <BoxedKernelFunction_withDispatchKeys* func>
|
| 162 |
+
static BoxedKernel makeFromFunction();
|
| 163 |
+
|
| 164 |
+
/**
|
| 165 |
+
* Create a BoxedKernel from a boxed functor.
|
| 166 |
+
*
|
| 167 |
+
* Example:
|
| 168 |
+
*
|
| 169 |
+
* > class MyFunctor final : public c10::OperatorKernel {
|
| 170 |
+
* > public:
|
| 171 |
+
* > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
|
| 172 |
+
* > };
|
| 173 |
+
* > BoxedKernel func =
|
| 174 |
+
* BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
|
| 175 |
+
*/
|
| 176 |
+
template <class KernelFunctor>
|
| 177 |
+
static BoxedKernel makeFromFunctor(
|
| 178 |
+
std::unique_ptr<KernelFunctor> kernelFunctor);
|
| 179 |
+
|
| 180 |
+
static BoxedKernel makeFallthrough();
|
| 181 |
+
static BoxedKernel makeAmbiguousAutogradOther();
|
| 182 |
+
static BoxedKernel makeNamedNotSupported();
|
| 183 |
+
|
| 184 |
+
private:
|
| 185 |
+
friend class KernelFunction;
|
| 186 |
+
|
| 187 |
+
template <BoxedKernelFunction* func>
|
| 188 |
+
static void make_boxed_function(
|
| 189 |
+
OperatorKernel* /*unused*/,
|
| 190 |
+
const OperatorHandle& opHandle,
|
| 191 |
+
DispatchKeySet /*unused*/,
|
| 192 |
+
Stack* stack);
|
| 193 |
+
|
| 194 |
+
template <BoxedKernelFunction_withDispatchKeys* func>
|
| 195 |
+
static void make_boxed_function(
|
| 196 |
+
OperatorKernel* /*unused*/,
|
| 197 |
+
const OperatorHandle& opHandle,
|
| 198 |
+
DispatchKeySet /*ks*/,
|
| 199 |
+
Stack* stack);
|
| 200 |
+
|
| 201 |
+
explicit BoxedKernel(
|
| 202 |
+
std::unique_ptr<OperatorKernel> functor,
|
| 203 |
+
InternalBoxedKernelFunction* boxed_kernel_func);
|
| 204 |
+
|
| 205 |
+
OperatorKernel* getFunctor() const;
|
| 206 |
+
InternalBoxedKernelFunction* getFnPtr() const;
|
| 207 |
+
|
| 208 |
+
c10::intrusive_ptr<OperatorKernel> functor_;
|
| 209 |
+
InternalBoxedKernelFunction* boxed_kernel_func_;
|
| 210 |
+
};
|
| 211 |
+
|
| 212 |
+
} // namespace c10
|
| 213 |
+
|
| 214 |
+
#include <ATen/core/boxing/BoxedKernel_impl.h>
|
| 215 |
+
|
| 216 |
+
#else
|
| 217 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 218 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
namespace c10 {
|
| 5 |
+
|
| 6 |
+
inline BoxedKernel::BoxedKernel() : boxed_kernel_func_(nullptr) {}
|
| 7 |
+
|
| 8 |
+
inline BoxedKernel::BoxedKernel(
|
| 9 |
+
std::unique_ptr<OperatorKernel> functor,
|
| 10 |
+
InternalBoxedKernelFunction* boxed_kernel_func)
|
| 11 |
+
: functor_(std::move(functor)), boxed_kernel_func_(boxed_kernel_func) {}
|
| 12 |
+
|
| 13 |
+
template <BoxedKernel::BoxedKernelFunction* func>
|
| 14 |
+
inline void BoxedKernel::make_boxed_function(
|
| 15 |
+
OperatorKernel* /*unused*/,
|
| 16 |
+
const OperatorHandle& opHandle,
|
| 17 |
+
DispatchKeySet /*unused*/,
|
| 18 |
+
Stack* stack) {
|
| 19 |
+
// Note that we're dropping the DispatchKeySet argument.
|
| 20 |
+
// See Note [Plumbing Keys Through The Dispatcher 2] for details.
|
| 21 |
+
func(opHandle, stack);
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
|
| 25 |
+
inline void BoxedKernel::make_boxed_function(
|
| 26 |
+
OperatorKernel* /*unused*/,
|
| 27 |
+
const OperatorHandle& opHandle,
|
| 28 |
+
DispatchKeySet ks,
|
| 29 |
+
Stack* stack) {
|
| 30 |
+
// See Note [Plumbing Keys Through The Dispatcher 2] for details.
|
| 31 |
+
func(opHandle, ks, stack);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
inline bool BoxedKernel::isValid() const {
|
| 35 |
+
return boxed_kernel_func_ != nullptr;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
inline bool BoxedKernel::isFallthrough() const {
|
| 39 |
+
return boxed_kernel_func_ == &fallthrough_kernel;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
inline void BoxedKernel::callBoxed(
|
| 43 |
+
const OperatorHandle& opHandle,
|
| 44 |
+
DispatchKeySet dispatchKeySet,
|
| 45 |
+
Stack* stack) const {
|
| 46 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 47 |
+
boxed_kernel_func_ != nullptr,
|
| 48 |
+
"Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel.");
|
| 49 |
+
(*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
template <BoxedKernel::BoxedKernelFunction* func>
|
| 53 |
+
inline BoxedKernel BoxedKernel::makeFromFunction() {
|
| 54 |
+
return BoxedKernel(
|
| 55 |
+
nullptr, // no functor_ object
|
| 56 |
+
&make_boxed_function<func>);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
|
| 60 |
+
inline BoxedKernel BoxedKernel::makeFromFunction() {
|
| 61 |
+
return BoxedKernel(
|
| 62 |
+
nullptr, // no functor_ object
|
| 63 |
+
&make_boxed_function<func>);
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
inline BoxedKernel BoxedKernel::makeFallthrough() {
|
| 67 |
+
return BoxedKernel(
|
| 68 |
+
nullptr, // no functor_ object
|
| 69 |
+
&fallthrough_kernel);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
|
| 73 |
+
return BoxedKernel(
|
| 74 |
+
nullptr, // no functor_ object
|
| 75 |
+
&ambiguous_autogradother_kernel);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
|
| 79 |
+
return BoxedKernel(
|
| 80 |
+
nullptr, // no functor_ object
|
| 81 |
+
&named_not_supported_kernel);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
template <class KernelFunctor>
|
| 85 |
+
inline BoxedKernel BoxedKernel::makeFromFunctor(
|
| 86 |
+
std::unique_ptr<KernelFunctor> kernelFunctor) {
|
| 87 |
+
static_assert(
|
| 88 |
+
std::is_base_of_v<OperatorKernel, KernelFunctor>,
|
| 89 |
+
"Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
| 90 |
+
return BoxedKernel(
|
| 91 |
+
std::move(kernelFunctor),
|
| 92 |
+
[](OperatorKernel* kernel,
|
| 93 |
+
const OperatorHandle& op,
|
| 94 |
+
DispatchKeySet ks,
|
| 95 |
+
Stack* stack) {
|
| 96 |
+
(*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
|
| 97 |
+
});
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
inline OperatorKernel* BoxedKernel::getFunctor() const {
|
| 101 |
+
return functor_.get();
|
| 102 |
+
}
|
| 103 |
+
inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
|
| 104 |
+
return boxed_kernel_func_;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
} // namespace c10
|
| 108 |
+
|
| 109 |
+
#else
|
| 110 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 111 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction.h
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/ATen_fwd.h>
|
| 5 |
+
#include <ATen/core/boxing/BoxedKernel.h>
|
| 6 |
+
#include <ATen/core/stack.h>
|
| 7 |
+
#include <c10/core/DispatchKeySet.h>
|
| 8 |
+
#include <c10/util/TypeList.h>
|
| 9 |
+
#include <c10/util/intrusive_ptr.h>
|
| 10 |
+
#include <atomic>
|
| 11 |
+
#include <memory>
|
| 12 |
+
#include <type_traits>
|
| 13 |
+
|
| 14 |
+
namespace c10 {
|
| 15 |
+
|
| 16 |
+
using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
|
| 17 |
+
// to the c10 namespace.
|
| 18 |
+
|
| 19 |
+
class OperatorHandle;
|
| 20 |
+
struct OperatorKernel;
|
| 21 |
+
class KernelFunction;
|
| 22 |
+
|
| 23 |
+
class KernelToken;
|
| 24 |
+
class SafeKernelFunction;
|
| 25 |
+
|
| 26 |
+
template <typename T>
|
| 27 |
+
using has_symint = std::disjunction<
|
| 28 |
+
std::is_same<c10::SymInt, T>,
|
| 29 |
+
std::is_same<c10::SymIntArrayRef, T>,
|
| 30 |
+
std::is_same<at::OptionalSymIntArrayRef, T>,
|
| 31 |
+
std::is_same<std::optional<c10::SymInt>, T>>;
|
| 32 |
+
|
| 33 |
+
template <typename T>
|
| 34 |
+
struct remove_symint {
|
| 35 |
+
using type = T;
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
template <>
|
| 39 |
+
struct remove_symint<c10::SymInt> {
|
| 40 |
+
using type = int64_t;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
template <>
|
| 44 |
+
struct remove_symint<at::OptionalSymIntArrayRef> {
|
| 45 |
+
using type = OptionalIntArrayRef;
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
template <>
|
| 49 |
+
struct remove_symint<c10::SymIntArrayRef> {
|
| 50 |
+
using type = c10::IntArrayRef;
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
template <>
|
| 54 |
+
struct remove_symint<std::optional<c10::SymInt>> {
|
| 55 |
+
using type = std::optional<int64_t>;
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
template <bool symint, typename T>
|
| 59 |
+
struct maybe_keep_symint final {};
|
| 60 |
+
|
| 61 |
+
template <typename T>
|
| 62 |
+
struct maybe_keep_symint<true, T> {
|
| 63 |
+
using type = T;
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
template <typename T>
|
| 67 |
+
struct maybe_keep_symint<false, T> {
|
| 68 |
+
using type = typename remove_symint<T>::type;
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
template <typename T>
|
| 72 |
+
using fn_has_symint = typename guts::typelist::true_for_any_type<
|
| 73 |
+
has_symint,
|
| 74 |
+
typename guts::infer_function_traits<T>::type::parameter_types>;
|
| 75 |
+
|
| 76 |
+
template <typename T>
|
| 77 |
+
struct fn_remove_symint;
|
| 78 |
+
|
| 79 |
+
template <typename Ret, typename... Args>
|
| 80 |
+
struct fn_remove_symint<Ret(Args...)> {
|
| 81 |
+
using type = Ret(typename remove_symint<Args>::type...);
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
/**
|
| 85 |
+
* KernelFunction is similar to std::function but stores a kernel function.
|
| 86 |
+
* You can create a KernelFunction from a boxed or unboxed
|
| 87 |
+
* function/functor/lambda and call it in a boxed or unboxed way. If the way it
|
| 88 |
+
* was created doesn't match the way it was called, it will do boxing or
|
| 89 |
+
* unboxing as necessary.
|
| 90 |
+
*/
|
| 91 |
+
class TORCH_API KernelFunction final {
|
| 92 |
+
public:
|
| 93 |
+
using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
|
| 94 |
+
using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
|
| 95 |
+
using BoxedKernelFunction_withDispatchKeys =
|
| 96 |
+
BoxedKernel::BoxedKernelFunction_withDispatchKeys;
|
| 97 |
+
|
| 98 |
+
KernelFunction();
|
| 99 |
+
~KernelFunction();
|
| 100 |
+
|
| 101 |
+
KernelFunction(const KernelFunction& other);
|
| 102 |
+
KernelFunction& operator=(const KernelFunction& other);
|
| 103 |
+
|
| 104 |
+
KernelFunction(KernelFunction&&) noexcept = default;
|
| 105 |
+
|
| 106 |
+
// Fast path for dispatch to allow not touching the boxed kernel in
|
| 107 |
+
// the common case where unboxed is available.
|
| 108 |
+
bool isValidUnboxed() const;
|
| 109 |
+
bool isValidSymUnboxed() const;
|
| 110 |
+
bool isValid() const;
|
| 111 |
+
bool isFallthrough() const;
|
| 112 |
+
|
| 113 |
+
/**
|
| 114 |
+
* Call the function in a boxed way.
|
| 115 |
+
* If the kernel function was created with an unboxed function,
|
| 116 |
+
* this will call an unboxing wrapper which then calls into that
|
| 117 |
+
* unboxed function.
|
| 118 |
+
*
|
| 119 |
+
* Example:
|
| 120 |
+
*
|
| 121 |
+
* > void boxed_func(const OperatorHandle&, Stack* stack) {...}
|
| 122 |
+
* > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
|
| 123 |
+
* > func.callBoxed(opHandle, dispatchKeySet, stack);
|
| 124 |
+
*
|
| 125 |
+
* Or, with an unboxed implementation:
|
| 126 |
+
*
|
| 127 |
+
* > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
|
| 128 |
+
* > [] (Tensor a, bool b) -> Tensor {...});
|
| 129 |
+
* > func.callBoxed(opHandle, dispatchKeySet, stack);
|
| 130 |
+
*/
|
| 131 |
+
void callBoxed(
|
| 132 |
+
const OperatorHandle& opHandle,
|
| 133 |
+
DispatchKeySet dispatchKeySet,
|
| 134 |
+
Stack* stack) const;
|
| 135 |
+
|
| 136 |
+
/**
|
| 137 |
+
* Call the function in an unboxed way.
|
| 138 |
+
* If the kernel function was created with a boxed function,
|
| 139 |
+
* this will box all inputs and then call into that boxed function.
|
| 140 |
+
*
|
| 141 |
+
* Note that this doesn't work for all types yet.
|
| 142 |
+
*
|
| 143 |
+
* Example:
|
| 144 |
+
*
|
| 145 |
+
* > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
|
| 146 |
+
* > [] (Tensor a, bool b) -> Tensor {...});
|
| 147 |
+
* > Tensor result = func.call<Tensor, Tensor, bool>(opHandle, dispatchKeySet, tensor1, true);
|
| 148 |
+
*
|
| 149 |
+
* Or, with a boxed implementation:
|
| 150 |
+
*
|
| 151 |
+
* > void boxed_func(const OperatorHandle&, Stack* stack) {...}
|
| 152 |
+
* > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
|
| 153 |
+
* > Tensor result = func.call<Tensor, Tensor, bool>(opHandle, dispatchKeySet, tensor1, true);
|
| 154 |
+
*/
|
| 155 |
+
template <class Return, class... Args>
|
| 156 |
+
Return call(
|
| 157 |
+
const OperatorHandle& opHandle,
|
| 158 |
+
DispatchKeySet dispatchKeySet,
|
| 159 |
+
Args... args) const;
|
| 160 |
+
|
| 161 |
+
/**
|
| 162 |
+
* Create a KernelFunction from a BoxedKernel.
|
| 163 |
+
*/
|
| 164 |
+
static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);
|
| 165 |
+
|
| 166 |
+
/**
|
| 167 |
+
* Create a KernelFunction from a boxed function.
|
| 168 |
+
*
|
| 169 |
+
* Example:
|
| 170 |
+
*
|
| 171 |
+
* > void boxed_func(const OperatorHandle&, Stack* stack) {...}
|
| 172 |
+
* > KernelFunction func =
|
| 173 |
+
* KernelFunction::makeFromBoxedFunction<&boxed_func>();
|
| 174 |
+
*/
|
| 175 |
+
template <BoxedKernelFunction* func>
|
| 176 |
+
static KernelFunction makeFromBoxedFunction();
|
| 177 |
+
|
| 178 |
+
/**
|
| 179 |
+
* TODO: This will only be useful if we write a backend fallback that plumbs
|
| 180 |
+
* dispatch keys (currently there are none) See Note [Plumbing Keys Through
|
| 181 |
+
* The Dispatcher] for details.
|
| 182 |
+
*/
|
| 183 |
+
template <BoxedKernelFunction_withDispatchKeys* func>
|
| 184 |
+
static KernelFunction makeFromBoxedFunction();
|
| 185 |
+
|
| 186 |
+
/**
|
| 187 |
+
* Create a KernelFunction from an unboxed functor.
|
| 188 |
+
*
|
| 189 |
+
* Example:
|
| 190 |
+
*
|
| 191 |
+
* > class MyFunctor final : public c10::OperatorKernel {
|
| 192 |
+
* > public:
|
| 193 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 194 |
+
* > };
|
| 195 |
+
* > KernelFunction func =
|
| 196 |
+
* KernelFunction::makeFromUnboxedFunctor<false, MyFunctor>(std::make_unique<MyFunctor>());
|
| 197 |
+
*/
|
| 198 |
+
template <bool AllowLegacyTypes = false, class KernelFunctor>
|
| 199 |
+
static KernelFunction makeFromUnboxedFunctor(
|
| 200 |
+
std::unique_ptr<OperatorKernel> kernelFunctor);
|
| 201 |
+
|
| 202 |
+
/**
|
| 203 |
+
* Create a KernelFunction from a boxed functor.
|
| 204 |
+
*
|
| 205 |
+
* Example:
|
| 206 |
+
*
|
| 207 |
+
* > class MyFunctor final : public c10::OperatorKernel {
|
| 208 |
+
* > public:
|
| 209 |
+
* > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
|
| 210 |
+
* > };
|
| 211 |
+
* > KernelFunction func =
|
| 212 |
+
* KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
|
| 213 |
+
*/
|
| 214 |
+
template <class KernelFunctor>
|
| 215 |
+
static KernelFunction makeFromBoxedFunctor(
|
| 216 |
+
std::unique_ptr<KernelFunctor> kernelFunctor);
|
| 217 |
+
|
| 218 |
+
/**
|
| 219 |
+
* Create a KernelFunction from an unboxed function.
|
| 220 |
+
* This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
|
| 221 |
+
* because knowing the function pointer as a template argument (i.e. at
|
| 222 |
+
* compile time) allows the compiler to inline the function into its
|
| 223 |
+
* unboxing wrapper and yields better performance when calling the function.
|
| 224 |
+
*
|
| 225 |
+
* Example:
|
| 226 |
+
*
|
| 227 |
+
* > Tensor unboxed_func(Tensor a, Tensor b) {...}
|
| 228 |
+
* > KernelFunction func =
|
| 229 |
+
* KernelFunction::makeFromUnboxedFunction(
|
| 230 |
+
* TORCH_FN(unboxed_func));
|
| 231 |
+
*/
|
| 232 |
+
template <class FuncPtr, bool AllowLegacyTypes = false>
|
| 233 |
+
static KernelFunction makeFromUnboxedFunction(FuncPtr /*func_ptr*/);
|
| 234 |
+
|
| 235 |
+
/**
|
| 236 |
+
* Create a KernelFunction from an unboxed function.
|
| 237 |
+
* KernelFunction::makeFromUnboxedFunction is usually a better choice than
|
| 238 |
+
* this if you know the function pointer at compile time, see doc comment
|
| 239 |
+
* there for an explanation.
|
| 240 |
+
*
|
| 241 |
+
* Example:
|
| 242 |
+
*
|
| 243 |
+
* > Tensor unboxed_func(Tensor a, Tensor b) {...}
|
| 244 |
+
* > KernelFunction func =
|
| 245 |
+
* KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
|
| 246 |
+
*/
|
| 247 |
+
template <bool AllowLegacyTypes = false, class FuncType>
|
| 248 |
+
static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
|
| 249 |
+
|
| 250 |
+
static KernelFunction makeFallthrough();
|
| 251 |
+
static KernelFunction makeAmbiguousAutogradOther();
|
| 252 |
+
static KernelFunction makeNamedNotSupported();
|
| 253 |
+
|
| 254 |
+
/**
|
| 255 |
+
* Create a KernelFunction from an unboxed lambda.
|
| 256 |
+
*
|
| 257 |
+
* Example:
|
| 258 |
+
*
|
| 259 |
+
* > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
|
| 260 |
+
* > [] (Tensor a, bool b) -> Tensor {...});
|
| 261 |
+
*/
|
| 262 |
+
template <bool AllowLegacyTypes = false, class Lambda>
|
| 263 |
+
static std::enable_if_t<
|
| 264 |
+
guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
|
| 265 |
+
KernelFunction>
|
| 266 |
+
makeFromUnboxedLambda(Lambda&& lambda);
|
| 267 |
+
template <bool AllowLegacyTypes = false, class Lambda>
|
| 268 |
+
static std::enable_if_t<
|
| 269 |
+
!guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
|
| 270 |
+
KernelFunction>
|
| 271 |
+
makeFromUnboxedLambda(Lambda&& lambda);
|
| 272 |
+
|
| 273 |
+
std::string dumpState() const;
|
| 274 |
+
// For testing internal invariants only
|
| 275 |
+
bool _equalsBoxedAndUnboxed(const KernelFunction& /*other*/) const;
|
| 276 |
+
|
| 277 |
+
// Register a token to be invalidated when this KernelFunction is destroyed
|
| 278 |
+
void registerToken(std::weak_ptr<KernelToken> token) const;
|
| 279 |
+
|
| 280 |
+
private:
|
| 281 |
+
explicit KernelFunction(
|
| 282 |
+
std::unique_ptr<OperatorKernel> functor,
|
| 283 |
+
InternalBoxedKernelFunction* boxed_kernel_func,
|
| 284 |
+
void* unboxed_kernel_func,
|
| 285 |
+
void* sym_unboxed_kernel_func);
|
| 286 |
+
explicit KernelFunction(
|
| 287 |
+
BoxedKernel boxed_fn,
|
| 288 |
+
void* unboxed_kernel_func,
|
| 289 |
+
void* sym_unboxed_kernel_func);
|
| 290 |
+
|
| 291 |
+
BoxedKernel boxed_kernel_func_;
|
| 292 |
+
void* unboxed_kernel_func_;
|
| 293 |
+
void* sym_unboxed_kernel_func_;
|
| 294 |
+
// List of tokens that need to be invalidated when this KernelFunction is
|
| 295 |
+
// destroyed (lazy allocation to save memory when empty)
|
| 296 |
+
mutable std::unique_ptr<std::vector<std::weak_ptr<KernelToken>>> tokens_;
|
| 297 |
+
};
|
| 298 |
+
|
| 299 |
+
// Liveness flag shared between a KernelFunction and the SafeKernelFunction
// objects that registered with it. The KernelFunction destructor invalidates
// every registered token that is still alive.
class KernelToken {
 public:
  // Flags the token as no longer usable.
  void invalidate();

  // Returns true as long as invalidate() has not been called.
  bool isValid() const;

 private:
  // Starts out false; set to true by invalidate().
  std::atomic<bool> invalid_{false};
};
|
| 309 |
+
|
| 310 |
+
class SafeKernelFunction {
|
| 311 |
+
public:
|
| 312 |
+
SafeKernelFunction(
|
| 313 |
+
const KernelFunction* kernel,
|
| 314 |
+
std::string debug,
|
| 315 |
+
std::shared_ptr<OperatorHandle> opHandle);
|
| 316 |
+
|
| 317 |
+
// Safe callBoxed - checks token validity first
|
| 318 |
+
void callBoxed(
|
| 319 |
+
const OperatorHandle& opHandle,
|
| 320 |
+
DispatchKeySet dispatchKeySet,
|
| 321 |
+
Stack* stack) const;
|
| 322 |
+
|
| 323 |
+
// Get debug information
|
| 324 |
+
const std::string& debug() const {
|
| 325 |
+
return debug_;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
// Get the OpHandle that lives on this SafeKernelFunction
|
| 329 |
+
const OperatorHandle& opHandle() const {
|
| 330 |
+
return *opHandle_;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
private:
|
| 334 |
+
KernelFunction kernel_;
|
| 335 |
+
std::shared_ptr<KernelToken> token_;
|
| 336 |
+
std::string debug_;
|
| 337 |
+
std::shared_ptr<OperatorHandle> opHandle_;
|
| 338 |
+
};
|
| 339 |
+
|
| 340 |
+
} // namespace c10
|
| 341 |
+
|
| 342 |
+
#include <ATen/core/boxing/KernelFunction_impl.h>
|
| 343 |
+
|
| 344 |
+
#else
|
| 345 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 346 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
|
| 3 |
+
#include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h>
|
| 4 |
+
#include <ATen/core/boxing/impl/boxing.h>
|
| 5 |
+
#include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
|
| 6 |
+
|
| 7 |
+
#include <c10/util/C++17.h>
|
| 8 |
+
#include <type_traits>
|
| 9 |
+
|
| 10 |
+
namespace c10 {
|
| 11 |
+
|
| 12 |
+
namespace detail {
// Constructs a `Child` from the forwarded arguments and hands it back as a
// std::unique_ptr<Base>. Only participates in overload resolution when
// neither type is an array type and Base is a base of Child.
template <typename Base, typename Child, typename... Args>
auto make_unique_base(Args&&... args) -> std::enable_if_t<
    !std::is_array_v<Base> && !std::is_array_v<Child> &&
        std::is_base_of_v<Base, Child>,
    std::unique_ptr<Base>> {
  return std::unique_ptr<Base>(
      std::make_unique<Child>(std::forward<Args>(args)...));
}
} // namespace detail
|
| 22 |
+
|
| 23 |
+
inline KernelFunction::KernelFunction()
|
| 24 |
+
: unboxed_kernel_func_(nullptr), sym_unboxed_kernel_func_(nullptr) {}
|
| 25 |
+
|
| 26 |
+
// Invalidates every registered token that is still alive.
inline KernelFunction::~KernelFunction() {
  // The token list is allocated lazily; nothing to do if it never was.
  if (!tokens_) {
    return;
  }
  for (auto& observer : *tokens_) {
    auto alive = observer.lock();
    if (alive) {
      alive->invalidate();
    }
  }
}
|
| 35 |
+
|
| 36 |
+
inline KernelFunction::KernelFunction(const KernelFunction& other)
|
| 37 |
+
: boxed_kernel_func_(other.boxed_kernel_func_),
|
| 38 |
+
unboxed_kernel_func_(other.unboxed_kernel_func_),
|
| 39 |
+
sym_unboxed_kernel_func_(other.sym_unboxed_kernel_func_) {
|
| 40 |
+
// tokens_ is intentionally not copied as we only care about invalidating
|
| 41 |
+
// tokens if the original KernelFunction is destroyed
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// Copy-assigns the kernel pointers. tokens_ is left untouched on purpose:
// tokens are only invalidated when the KernelFunction they were registered
// with is destroyed, so they are never copied.
inline KernelFunction& KernelFunction::operator=(const KernelFunction& other) {
  if (this == &other) {
    return *this;
  }
  boxed_kernel_func_ = other.boxed_kernel_func_;
  unboxed_kernel_func_ = other.unboxed_kernel_func_;
  sym_unboxed_kernel_func_ = other.sym_unboxed_kernel_func_;
  return *this;
}
|
| 55 |
+
|
| 56 |
+
inline KernelFunction::KernelFunction(
|
| 57 |
+
std::unique_ptr<OperatorKernel> functor,
|
| 58 |
+
InternalBoxedKernelFunction* boxed_kernel_func,
|
| 59 |
+
void* unboxed_kernel_func,
|
| 60 |
+
void* sym_unboxed_kernel_func = nullptr)
|
| 61 |
+
: boxed_kernel_func_(std::move(functor), boxed_kernel_func),
|
| 62 |
+
unboxed_kernel_func_(unboxed_kernel_func),
|
| 63 |
+
sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}
|
| 64 |
+
|
| 65 |
+
inline KernelFunction::KernelFunction(
|
| 66 |
+
BoxedKernel boxed_fn,
|
| 67 |
+
void* unboxed_kernel_func,
|
| 68 |
+
void* sym_unboxed_kernel_func = nullptr)
|
| 69 |
+
: boxed_kernel_func_(std::move(boxed_fn)),
|
| 70 |
+
unboxed_kernel_func_(unboxed_kernel_func),
|
| 71 |
+
sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}
|
| 72 |
+
|
| 73 |
+
inline bool KernelFunction::isValidUnboxed() const {
|
| 74 |
+
return unboxed_kernel_func_ != nullptr;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
inline bool KernelFunction::isValidSymUnboxed() const {
|
| 78 |
+
return sym_unboxed_kernel_func_ != nullptr;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
inline bool KernelFunction::isValid() const {
|
| 82 |
+
return boxed_kernel_func_.isValid();
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
inline bool KernelFunction::isFallthrough() const {
|
| 86 |
+
return boxed_kernel_func_.isFallthrough();
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
inline void KernelFunction::callBoxed(
|
| 90 |
+
const OperatorHandle& opHandle,
|
| 91 |
+
DispatchKeySet dispatchKeySet,
|
| 92 |
+
Stack* stack) const {
|
| 93 |
+
boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
template <class Return, class... Args>
|
| 97 |
+
inline Return callUnboxedKernelFunction(
|
| 98 |
+
void* unboxed_kernel_func,
|
| 99 |
+
OperatorKernel* functor,
|
| 100 |
+
DispatchKeySet dispatchKeySet,
|
| 101 |
+
Args&&... args) {
|
| 102 |
+
using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...);
|
| 103 |
+
ActualSignature* func =
|
| 104 |
+
reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
|
| 105 |
+
return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// This template requires you to explicitly specify the argument you want to
|
| 109 |
+
// forward; it doesn't work if you try to deduce it
|
| 110 |
+
// NB: keep this in sync with cloneWithRealTypes in function_schema.cpp
|
| 111 |
+
|
| 112 |
+
template <typename T>
|
| 113 |
+
inline typename remove_symint<T>::type unpackSymInt(T x) {
|
| 114 |
+
return x;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
template <>
|
| 118 |
+
inline remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
|
| 119 |
+
return x.guard_int(__FILE__, __LINE__);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
template <>
|
| 123 |
+
inline remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
|
| 124 |
+
c10::SymIntArrayRef x) {
|
| 125 |
+
return C10_AS_INTARRAYREF_SLOW(x);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
template <>
|
| 129 |
+
inline remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
|
| 130 |
+
std::optional<c10::SymInt> x) {
|
| 131 |
+
return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
|
| 132 |
+
: std::nullopt;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
template <>
|
| 136 |
+
inline remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
|
| 137 |
+
at::OptionalSymIntArrayRef x) {
|
| 138 |
+
return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
|
| 139 |
+
: std::nullopt;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
template <class Return, class... Args>
|
| 143 |
+
C10_ALWAYS_INLINE Return KernelFunction::call(
|
| 144 |
+
const OperatorHandle& opHandle,
|
| 145 |
+
DispatchKeySet dispatchKeySet,
|
| 146 |
+
Args... args) const {
|
| 147 |
+
// note: Args above is intentionally not Args&&. We don't want perfect
|
| 148 |
+
// forwarding, which would require Args to be deduced, but instead we
|
| 149 |
+
// want callers to explicitly specify the Args.
|
| 150 |
+
|
| 151 |
+
if constexpr (std::disjunction_v<has_symint<Args>...>) {
|
| 152 |
+
if (sym_unboxed_kernel_func_ != nullptr) {
|
| 153 |
+
auto* functor = boxed_kernel_func_.getFunctor();
|
| 154 |
+
return callUnboxedKernelFunction<Return, Args...>(
|
| 155 |
+
sym_unboxed_kernel_func_,
|
| 156 |
+
functor,
|
| 157 |
+
dispatchKeySet,
|
| 158 |
+
std::forward<Args>(args)...);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
if (unboxed_kernel_func_ != nullptr) {
|
| 162 |
+
auto* functor = boxed_kernel_func_.getFunctor();
|
| 163 |
+
return callUnboxedKernelFunction<
|
| 164 |
+
Return,
|
| 165 |
+
typename remove_symint<Args>::type...>(
|
| 166 |
+
unboxed_kernel_func_,
|
| 167 |
+
functor,
|
| 168 |
+
dispatchKeySet,
|
| 169 |
+
unpackSymInt<Args>(args)...);
|
| 170 |
+
}
|
| 171 |
+
} else {
|
| 172 |
+
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
|
| 173 |
+
auto* functor = boxed_kernel_func_.getFunctor();
|
| 174 |
+
return callUnboxedKernelFunction<Return, Args...>(
|
| 175 |
+
unboxed_kernel_func_,
|
| 176 |
+
functor,
|
| 177 |
+
dispatchKeySet,
|
| 178 |
+
std::forward<Args>(args)...);
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
return impl::BoxedKernelWrapper<Return(Args...)>::call(
|
| 183 |
+
boxed_kernel_func_,
|
| 184 |
+
opHandle,
|
| 185 |
+
dispatchKeySet,
|
| 186 |
+
std::forward<Args>(args)...);
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
inline void KernelFunction::registerToken(
|
| 190 |
+
std::weak_ptr<KernelToken> token) const {
|
| 191 |
+
if (!tokens_) {
|
| 192 |
+
tokens_ = std::make_unique<std::vector<std::weak_ptr<KernelToken>>>();
|
| 193 |
+
}
|
| 194 |
+
tokens_->push_back(std::move(token));
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// Wraps a BoxedKernel; the result has no unboxed entry point.
inline KernelFunction KernelFunction::makeFromBoxedKernel(
    BoxedKernel boxed_fn) {
  return KernelFunction(
      std::move(boxed_fn), /*unboxed_kernel_func=*/nullptr);
}
|
| 202 |
+
|
| 203 |
+
template <KernelFunction::BoxedKernelFunction* func>
|
| 204 |
+
inline KernelFunction KernelFunction::makeFromBoxedFunction() {
|
| 205 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 206 |
+
BoxedKernel::makeFromFunction<func>());
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
template <KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
|
| 210 |
+
inline KernelFunction KernelFunction::makeFromBoxedFunction() {
|
| 211 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 212 |
+
BoxedKernel::makeFromFunction<func>());
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
// Wraps BoxedKernel::makeFallthrough() in a KernelFunction.
inline KernelFunction KernelFunction::makeFallthrough() {
  auto boxed = BoxedKernel::makeFallthrough();
  return KernelFunction::makeFromBoxedKernel(std::move(boxed));
}
|
| 218 |
+
|
| 219 |
+
// Wraps BoxedKernel::makeAmbiguousAutogradOther() in a KernelFunction.
inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() {
  auto boxed = BoxedKernel::makeAmbiguousAutogradOther();
  return KernelFunction::makeFromBoxedKernel(std::move(boxed));
}
|
| 223 |
+
|
| 224 |
+
// Wraps BoxedKernel::makeNamedNotSupported() in a KernelFunction.
inline KernelFunction KernelFunction::makeNamedNotSupported() {
  auto boxed = BoxedKernel::makeNamedNotSupported();
  return KernelFunction::makeFromBoxedKernel(std::move(boxed));
}
|
| 228 |
+
|
| 229 |
+
template <bool AllowLegacyTypes, class KernelFunctor>
|
| 230 |
+
inline KernelFunction KernelFunction::makeFromUnboxedFunctor(
|
| 231 |
+
std::unique_ptr<OperatorKernel> kernelFunctor) {
|
| 232 |
+
#ifndef NDEBUG
|
| 233 |
+
// This assertion is costly for build time so it's debug-gated.
|
| 234 |
+
static_assert(
|
| 235 |
+
guts::is_functor<KernelFunctor>::value,
|
| 236 |
+
"Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
|
| 237 |
+
#endif
|
| 238 |
+
static_assert(
|
| 239 |
+
std::is_base_of_v<OperatorKernel, KernelFunctor>,
|
| 240 |
+
"Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
| 241 |
+
|
| 242 |
+
auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
|
| 243 |
+
void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
|
| 244 |
+
bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
|
| 245 |
+
return KernelFunction(
|
| 246 |
+
std::move(kernelFunctor),
|
| 247 |
+
&impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::
|
| 248 |
+
call,
|
| 249 |
+
is_symint ? nullptr : void_unboxed_fn,
|
| 250 |
+
is_symint ? void_unboxed_fn : nullptr);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
template <class KernelFunctor>
|
| 254 |
+
inline KernelFunction KernelFunction::makeFromBoxedFunctor(
|
| 255 |
+
std::unique_ptr<KernelFunctor> kernelFunctor) {
|
| 256 |
+
return KernelFunction::makeFromBoxedKernel(
|
| 257 |
+
BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
template <class FuncPtr, bool AllowLegacyTypes>
|
| 261 |
+
inline KernelFunction KernelFunction::makeFromUnboxedFunction(
|
| 262 |
+
FuncPtr func_ptr) {
|
| 263 |
+
static_assert(
|
| 264 |
+
is_compile_time_function_pointer<FuncPtr>::value,
|
| 265 |
+
"Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
|
| 266 |
+
static_assert(
|
| 267 |
+
!std::is_same_v<typename FuncPtr::FuncType, BoxedKernelFunction>,
|
| 268 |
+
"Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
|
| 269 |
+
#if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__)
|
| 270 |
+
TORCH_INTERNAL_ASSERT(
|
| 271 |
+
FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
|
| 272 |
+
#else
|
| 273 |
+
static_assert(
|
| 274 |
+
FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
|
| 275 |
+
#endif
|
| 276 |
+
|
| 277 |
+
#if !defined(C10_MOBILE)
|
| 278 |
+
(void)func_ptr; // Suppress unused variable warning
|
| 279 |
+
return makeFromUnboxedFunctor<
|
| 280 |
+
AllowLegacyTypes,
|
| 281 |
+
typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>(
|
| 282 |
+
detail::make_unique_base<
|
| 283 |
+
OperatorKernel,
|
| 284 |
+
typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>());
|
| 285 |
+
#else
|
| 286 |
+
// On mobile, we rather want to optimize for binary size than for performance,
|
| 287 |
+
// so let's not inline the kernel into the wrapper but use
|
| 288 |
+
// makeFromUnboxedRuntimeFunction instead.
|
| 289 |
+
return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr());
|
| 290 |
+
#endif
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
template <bool AllowLegacyTypes, class FuncType>
|
| 294 |
+
inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(
|
| 295 |
+
FuncType* func) {
|
| 296 |
+
static_assert(
|
| 297 |
+
guts::is_function_type<FuncType>::value,
|
| 298 |
+
"Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
|
| 299 |
+
static_assert(
|
| 300 |
+
!std::is_same_v<FuncType, BoxedKernelFunction>,
|
| 301 |
+
"Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
|
| 302 |
+
TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
|
| 303 |
+
|
| 304 |
+
return makeFromUnboxedFunctor<
|
| 305 |
+
AllowLegacyTypes,
|
| 306 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
|
| 307 |
+
detail::make_unique_base<
|
| 308 |
+
OperatorKernel,
|
| 309 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func));
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
template <bool AllowLegacyTypes, class Lambda>
|
| 313 |
+
inline std::enable_if_t<
|
| 314 |
+
guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
|
| 315 |
+
KernelFunction>
|
| 316 |
+
KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
|
| 317 |
+
static_assert(
|
| 318 |
+
guts::is_functor<std::decay_t<Lambda>>::value,
|
| 319 |
+
"Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
|
| 320 |
+
|
| 321 |
+
#if !defined(C10_MOBILE)
|
| 322 |
+
return makeFromUnboxedFunctor<
|
| 323 |
+
AllowLegacyTypes,
|
| 324 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 325 |
+
detail::make_unique_base<
|
| 326 |
+
OperatorKernel,
|
| 327 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 328 |
+
std::forward<Lambda>(lambda)));
|
| 329 |
+
#else
|
| 330 |
+
// On mobile, we rather want to optimize for binary size than for performance,
|
| 331 |
+
// so let's not inline the kernel into the wrapper but use
|
| 332 |
+
// makeFromUnboxedRuntimeFunction instead.
|
| 333 |
+
using FuncType =
|
| 334 |
+
typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type;
|
| 335 |
+
return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda);
|
| 336 |
+
#endif
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
template <bool AllowLegacyTypes, class Lambda>
|
| 340 |
+
inline std::enable_if_t<
|
| 341 |
+
!guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
|
| 342 |
+
KernelFunction>
|
| 343 |
+
KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
|
| 344 |
+
static_assert(
|
| 345 |
+
guts::is_functor<std::decay_t<Lambda>>::value,
|
| 346 |
+
"Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
|
| 347 |
+
|
| 348 |
+
return makeFromUnboxedFunctor<
|
| 349 |
+
AllowLegacyTypes,
|
| 350 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 351 |
+
detail::make_unique_base<
|
| 352 |
+
OperatorKernel,
|
| 353 |
+
impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
|
| 354 |
+
std::forward<Lambda>(lambda)));
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
inline bool KernelToken::isValid() const {
|
| 358 |
+
return !invalid_.load(std::memory_order_acquire);
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
// Release-store that sets the invalid flag to true.
inline void KernelToken::invalidate() {
  constexpr bool kInvalidated = true;
  invalid_.store(kInvalidated, std::memory_order_release);
}
|
| 364 |
+
|
| 365 |
+
// Stores a copy of *kernel (or a default KernelFunction when kernel is
// null), creates a fresh token, and registers that token with the original
// kernel so it gets invalidated when the original is destroyed.
inline SafeKernelFunction::SafeKernelFunction(
    const KernelFunction* kernel,
    std::string debug,
    std::shared_ptr<OperatorHandle> opHandle)
    : kernel_(kernel ? *kernel : KernelFunction()),
      token_(std::make_shared<KernelToken>()),
      debug_(std::move(debug)),
      opHandle_(std::move(opHandle)) {
  if (kernel == nullptr) {
    // No original kernel to observe.
    return;
  }
  kernel->registerToken(token_);
}
|
| 379 |
+
|
| 380 |
+
inline void SafeKernelFunction::callBoxed(
|
| 381 |
+
const OperatorHandle& opHandle,
|
| 382 |
+
DispatchKeySet dispatchKeySet,
|
| 383 |
+
Stack* stack) const {
|
| 384 |
+
TORCH_CHECK(
|
| 385 |
+
token_ && token_->isValid(),
|
| 386 |
+
"SafeKernelFunction has been invalidated ",
|
| 387 |
+
debug_);
|
| 388 |
+
kernel_.callBoxed(opHandle, dispatchKeySet, stack);
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
} // namespace c10
|
| 392 |
+
|
| 393 |
+
#else
|
| 394 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 395 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <c10/util/intrusive_ptr.h>
|
| 4 |
+
|
| 5 |
+
namespace c10 {
|
| 6 |
+
|
| 7 |
+
/**
|
| 8 |
+
* Inherit from OperatorKernel to implement a c10 kernel.
|
| 9 |
+
*
|
| 10 |
+
* Example:
|
| 11 |
+
* > namespace {
|
| 12 |
+
* > class my_kernel_cpu final : public c10::OperatorKernel {
|
| 13 |
+
* > public:
|
| 14 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 15 |
+
* > };
|
| 16 |
+
* > }
|
| 17 |
+
*
|
| 18 |
+
* The kernel class is allowed to have members but these are equivalent
|
| 19 |
+
* to global variables. The kernel implementation is responsible for
|
| 20 |
+
* preventing race conditions on them.
|
| 21 |
+
*
|
| 22 |
+
* See below for how to register this kernel with PyTorch.
|
| 23 |
+
*/
|
| 24 |
+
struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target {
|
| 25 |
+
~OperatorKernel() override = default;
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
} // namespace c10
|
| 29 |
+
|
| 30 |
+
#else
|
| 31 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 32 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/CompileTimeFunctionPointer.h>
|
| 5 |
+
|
| 6 |
+
namespace c10::impl {
|
| 7 |
+
namespace detail {
|
| 8 |
+
template <class FuncPtr, class ReturnType, class ParameterList>
|
| 9 |
+
class WrapFunctionIntoFunctor_ {};
|
| 10 |
+
template <class FuncPtr, class ReturnType, class... Parameters>
|
| 11 |
+
class WrapFunctionIntoFunctor_<
|
| 12 |
+
FuncPtr,
|
| 13 |
+
ReturnType,
|
| 14 |
+
guts::typelist::typelist<Parameters...>>
|
| 15 |
+
final : public c10::OperatorKernel {
|
| 16 |
+
public:
|
| 17 |
+
C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) {
|
| 18 |
+
return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...);
|
| 19 |
+
}
|
| 20 |
+
};
|
| 21 |
+
} // namespace detail
|
| 22 |
+
|
| 23 |
+
// WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel
|
| 24 |
+
// functor. Since it is a compile time function pointer, many compilers can
|
| 25 |
+
// inline it into the wrapper and you don't get any performance overhead for
|
| 26 |
+
// wrapping.
|
| 27 |
+
template <class FuncPtr>
|
| 28 |
+
struct WrapFunctionIntoFunctor final {
|
| 29 |
+
static_assert(
|
| 30 |
+
c10::is_compile_time_function_pointer<FuncPtr>::value,
|
| 31 |
+
"WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN.");
|
| 32 |
+
using type = detail::WrapFunctionIntoFunctor_<
|
| 33 |
+
FuncPtr,
|
| 34 |
+
typename guts::function_traits<typename FuncPtr::FuncType>::return_type,
|
| 35 |
+
typename guts::function_traits<
|
| 36 |
+
typename FuncPtr::FuncType>::parameter_types>;
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
} // namespace c10::impl
|
| 40 |
+
|
| 41 |
+
#else
|
| 42 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 43 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/util/TypeTraits.h>
|
| 5 |
+
|
| 6 |
+
namespace c10::impl {
|
| 7 |
+
|
| 8 |
+
namespace detail {
|
| 9 |
+
template <class FuncType, class ReturnType, class ParameterList>
|
| 10 |
+
class WrapFunctionIntoRuntimeFunctor_ {};
|
| 11 |
+
template <class FuncType, class ReturnType, class... Parameters>
|
| 12 |
+
class WrapFunctionIntoRuntimeFunctor_<
|
| 13 |
+
FuncType,
|
| 14 |
+
ReturnType,
|
| 15 |
+
guts::typelist::typelist<Parameters...>>
|
| 16 |
+
final : public c10::OperatorKernel {
|
| 17 |
+
public:
|
| 18 |
+
template <class FuncType_>
|
| 19 |
+
explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func)
|
| 20 |
+
: kernel_func_(std::forward<FuncType_>(kernel_func)) {}
|
| 21 |
+
|
| 22 |
+
decltype(auto) operator()(Parameters... args) {
|
| 23 |
+
return kernel_func_(std::forward<Parameters>(args)...);
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
private:
|
| 27 |
+
FuncType kernel_func_;
|
| 28 |
+
};
|
| 29 |
+
} // namespace detail
|
| 30 |
+
|
| 31 |
+
// WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that
|
| 32 |
+
// inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
|
| 33 |
+
// This can, for example, be used for lambdas, functors or even function
|
| 34 |
+
// pointers. In the case of function pointers, since it is a runtime function
|
| 35 |
+
// pointer, there is an overhead for calling it whenever the kernel is invoked.
|
| 36 |
+
template <class FuncType>
|
| 37 |
+
using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_<
|
| 38 |
+
FuncType,
|
| 39 |
+
typename guts::infer_function_traits_t<FuncType>::return_type,
|
| 40 |
+
typename guts::infer_function_traits_t<FuncType>::parameter_types>;
|
| 41 |
+
|
| 42 |
+
} // namespace c10::impl
|
| 43 |
+
|
| 44 |
+
#else
|
| 45 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 46 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/boxing.h
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// This file contains boxing (not unboxing) logic,
|
| 5 |
+
// i.e. how to make a vector<IValue> from a set of concrete arguments.
|
| 6 |
+
|
| 7 |
+
#include <ATen/core/ivalue.h>
|
| 8 |
+
#include <ATen/core/stack.h>
|
| 9 |
+
#include <c10/core/TensorOptions.h>
|
| 10 |
+
|
| 11 |
+
#include <ATen/core/boxing/BoxedKernel.h>
|
| 12 |
+
|
| 13 |
+
#include <c10/util/Metaprogramming.h>
|
| 14 |
+
#include <type_traits>
|
| 15 |
+
|
| 16 |
+
namespace c10::impl {
|
| 17 |
+
|
| 18 |
+
//
|
| 19 |
+
// utils
|
| 20 |
+
//
|
| 21 |
+
|
| 22 |
+
// is_mutable_tensor_ref
|
| 23 |
+
template <class T>
|
| 24 |
+
struct is_mutable_tensor_ref : std::false_type {};
|
| 25 |
+
template <>
|
| 26 |
+
struct is_mutable_tensor_ref<at::Tensor&> : std::true_type {};
|
| 27 |
+
|
| 28 |
+
// is_tuple_of_mutable_tensor_refs
|
| 29 |
+
//
|
| 30 |
+
template <class T, class Enable = void>
|
| 31 |
+
struct is_tuple_of_mutable_tensor_refs : std::false_type {};
|
| 32 |
+
|
| 33 |
+
template <class T>
|
| 34 |
+
struct is_tuple_of_mutable_tensor_refs<
|
| 35 |
+
T,
|
| 36 |
+
std::enable_if_t<guts::is_instantiation_of<std::tuple, T>::value, void>>
|
| 37 |
+
: guts::typelist::
|
| 38 |
+
all<is_mutable_tensor_ref, guts::typelist::from_tuple_t<T>> {};
|
| 39 |
+
|
| 40 |
+
// has_ivalue_to<T> tests the presence/absence of instance method
|
| 41 |
+
// IValue::to<T>()
|
| 42 |
+
//
|
| 43 |
+
template <class T, class Enable = void>
|
| 44 |
+
struct has_ivalue_to : std::false_type {};
|
| 45 |
+
|
| 46 |
+
template <class T>
|
| 47 |
+
struct ivalue_to_helper {
|
| 48 |
+
using type = decltype(std::declval<IValue>().template to<T>());
|
| 49 |
+
};
|
| 50 |
+
template <class T>
|
| 51 |
+
using ivalue_to_helper_t = typename ivalue_to_helper<T>::type;
|
| 52 |
+
|
| 53 |
+
template <class T>
|
| 54 |
+
struct has_ivalue_to<T, std::void_t<ivalue_to_helper_t<T>>> : std::true_type {};
|
| 55 |
+
|
| 56 |
+
//
|
| 57 |
+
// boxing predicates
|
| 58 |
+
//
|
| 59 |
+
|
| 60 |
+
// A boxable arg type is one that IValue has a constructor for.
|
| 61 |
+
template <typename T>
|
| 62 |
+
using can_box = std::disjunction<
|
| 63 |
+
std::is_constructible<IValue, std::decay_t<T>>,
|
| 64 |
+
// TensorOptions are not directly constructible into IValue,
|
| 65 |
+
// but torch::jit::push knows how to handle them
|
| 66 |
+
std::is_same<TensorOptions, std::decay_t<T>>>;
|
| 67 |
+
|
| 68 |
+
template <typename... Ts>
|
| 69 |
+
using can_box_all = std::conjunction<can_box<Ts>...>;
|
| 70 |
+
|
| 71 |
+
// an unboxable result is one that can be extracted from an IValue
|
| 72 |
+
template <typename T>
|
| 73 |
+
using can_unbox = std::conjunction<
|
| 74 |
+
std::disjunction<
|
| 75 |
+
has_ivalue_to<T>,
|
| 76 |
+
// void returns are ok
|
| 77 |
+
std::is_same<void, T>>,
|
| 78 |
+
std::negation<std::is_lvalue_reference<T>>>;
|
| 79 |
+
|
| 80 |
+
//
|
| 81 |
+
// boxArgs - utility for pushing unboxed args onto IValue stack
|
| 82 |
+
//
|
| 83 |
+
template <class... Args>
|
| 84 |
+
torch::jit::Stack boxArgs(Args... args) {
|
| 85 |
+
// TODO Reuse stack vector instead of allocating?
|
| 86 |
+
torch::jit::Stack stack;
|
| 87 |
+
stack.reserve(sizeof...(Args));
|
| 88 |
+
torch::jit::push(stack, std::forward<Args>(args)...);
|
| 89 |
+
return stack;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
template <class T>
|
| 93 |
+
inline constexpr size_t boxed_size_one() {
|
| 94 |
+
static_assert(
|
| 95 |
+
!std::is_same_v<std::decay_t<T>, c10::TensorOptions>,
|
| 96 |
+
"need to patch this path to support TensorOptions passed by reference");
|
| 97 |
+
return 1;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// torch::jit::push pushes 4 values for a TensorOptions; this needs to
|
| 101 |
+
// be kept in sync.
|
| 102 |
+
template <>
|
| 103 |
+
inline constexpr size_t boxed_size_one<c10::TensorOptions>() {
|
| 104 |
+
return 4;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
// NOTE: this could probably be simplified with C++17 fold expressions.
|
| 108 |
+
template <typename...>
|
| 109 |
+
struct BoxedSize : std::integral_constant<size_t, 0> {};
|
| 110 |
+
template <class T, class... Args>
|
| 111 |
+
struct BoxedSize<T, Args...>
|
| 112 |
+
: std::integral_constant<
|
| 113 |
+
size_t,
|
| 114 |
+
boxed_size_one<T>() + BoxedSize<Args...>::value> {};
|
| 115 |
+
|
| 116 |
+
template <class... Args>
|
| 117 |
+
static inline constexpr size_t boxed_size() {
|
| 118 |
+
return BoxedSize<Args...>::value;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
template <typename T>
|
| 122 |
+
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValue*& dest, T& arg) {
|
| 123 |
+
new (dest++) IValue(arg);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(
|
| 127 |
+
IValue*& dest,
|
| 128 |
+
c10::TensorOptions options) {
|
| 129 |
+
new (dest++) IValue(c10::typeMetaToScalarType(options.dtype()));
|
| 130 |
+
new (dest++) IValue(options.layout());
|
| 131 |
+
new (dest++) IValue(options.device());
|
| 132 |
+
new (dest++) IValue(options.pinned_memory());
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
inline void boxArgsToStack(IValue*& /*unused*/) {}
|
| 136 |
+
|
| 137 |
+
template <typename T, typename... Args>
|
| 138 |
+
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(
|
| 139 |
+
IValue*& dest,
|
| 140 |
+
T& arg,
|
| 141 |
+
Args&... args) {
|
| 142 |
+
boxToStack(dest, arg);
|
| 143 |
+
boxArgsToStack(dest, args...);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
//
|
| 147 |
+
// PopResult is a helper class whose specializations handle popping single and
|
| 148 |
+
// multiple return values, respectively.
|
| 149 |
+
//
|
| 150 |
+
template <class Result>
|
| 151 |
+
struct PopResult final {
|
| 152 |
+
static Result call(Stack& stack) {
|
| 153 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 154 |
+
stack.size() == 1,
|
| 155 |
+
"Boxed kernel was expected to return one value on the stack, ",
|
| 156 |
+
"but instead pushed ",
|
| 157 |
+
stack.size(),
|
| 158 |
+
" values.");
|
| 159 |
+
return std::move(stack[0]).to<Result>();
|
| 160 |
+
}
|
| 161 |
+
};
|
| 162 |
+
|
| 163 |
+
template <class... Types>
|
| 164 |
+
struct PopResult<std::tuple<Types...>> final {
|
| 165 |
+
using Result = std::tuple<Types...>;
|
| 166 |
+
|
| 167 |
+
static Result call(Stack& stack) {
|
| 168 |
+
// for tuple return types, boxed kernel has pushed multiple values onto the
|
| 169 |
+
// stack
|
| 170 |
+
constexpr int RetCount = sizeof...(Types);
|
| 171 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 172 |
+
stack.size() == RetCount,
|
| 173 |
+
"Boxed kernel was expected to return ",
|
| 174 |
+
RetCount,
|
| 175 |
+
" values on the stack, ",
|
| 176 |
+
"but instead pushed ",
|
| 177 |
+
stack.size(),
|
| 178 |
+
" values.");
|
| 179 |
+
return pop_to_tuple_impl(stack, std::make_index_sequence<RetCount>());
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
private:
|
| 183 |
+
// note: this has been moved into its own helper only to avoid a parse error
|
| 184 |
+
// on `indices` otherwise. I'm sure there's an incantation that slips it past
|
| 185 |
+
// the parser but eh
|
| 186 |
+
template <size_t... indices>
|
| 187 |
+
static Result pop_to_tuple_impl(
|
| 188 |
+
Stack& stack,
|
| 189 |
+
std::index_sequence<indices...> /*unused*/) {
|
| 190 |
+
return std::make_tuple((std::move(stack[indices]).template to<Types>())...);
|
| 191 |
+
}
|
| 192 |
+
};
|
| 193 |
+
|
| 194 |
+
//
|
| 195 |
+
// BoxedKernelWrapper
|
| 196 |
+
//
|
| 197 |
+
// For a given function type FT, BoxedKernelWrapper<FT> implements
|
| 198 |
+
// a `call` method that
|
| 199 |
+
// - takes a boxed kernel and unboxed arguments as specified by FT,
|
| 200 |
+
// - calls `boxArgs` to box the arguments
|
| 201 |
+
// - calls the boxed kernel
|
| 202 |
+
// - unboxes and returns the result
|
| 203 |
+
//
|
| 204 |
+
// The partial specializations below handle various cases: in
|
| 205 |
+
// particular, not all types appearing in op signatures are supported,
|
| 206 |
+
// and ops returning references have nonstandard wrapper implementations.
|
| 207 |
+
//
|
| 208 |
+
|
| 209 |
+
// 1. The base specialization of BoxedKernelWrapper should never be
|
| 210 |
+
// instantiated. A "no call method defined on BoxedKernelWrapper" compile error
|
| 211 |
+
// means that an op signature has failed to trigger any of the partial
|
| 212 |
+
// specializations that follow this one.
|
| 213 |
+
//
|
| 214 |
+
template <class FuncType, class Enable = void>
|
| 215 |
+
struct BoxedKernelWrapper {
|
| 216 |
+
// The reason we're not just doing straight up static_assert(false, ...) here:
|
| 217 |
+
// Basically, the way to make sure a static_assert only fires if a template
|
| 218 |
+
// is actually instantiated (rather than every time the file is parsed) is to
|
| 219 |
+
// use template parameters in the expression, e.g. FuncType here. However,
|
| 220 |
+
// since `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the
|
| 221 |
+
// same effect.
|
| 222 |
+
static_assert(
|
| 223 |
+
sizeof(FuncType) != sizeof(FuncType),
|
| 224 |
+
"Function signature contains one or more unsupported parameter and/or return types. "
|
| 225 |
+
"Look for a nearby error like "
|
| 226 |
+
"\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "
|
| 227 |
+
"- (your function type) is the unsupported signature.");
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
//
|
| 231 |
+
// 2. Supported signatures, other than those involving non-const Tensor refs -
|
| 232 |
+
// i.e., "functional" ops.
|
| 233 |
+
//
|
| 234 |
+
|
| 235 |
+
template <class Result, class... Args>
|
| 236 |
+
struct BoxedKernelWrapper<
|
| 237 |
+
Result(Args...),
|
| 238 |
+
std::enable_if_t<
|
| 239 |
+
can_box_all<Args...>::value && can_unbox<Result>::value &&
|
| 240 |
+
!is_tuple_of_mutable_tensor_refs<Result>::value,
|
| 241 |
+
void>> {
|
| 242 |
+
static Result call(
|
| 243 |
+
const BoxedKernel& boxed_kernel_func,
|
| 244 |
+
const OperatorHandle& opHandle,
|
| 245 |
+
DispatchKeySet dispatchKeySet,
|
| 246 |
+
Args... args) {
|
| 247 |
+
torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
|
| 248 |
+
boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
|
| 249 |
+
|
| 250 |
+
if constexpr (!std::is_same_v<void, Result>) {
|
| 251 |
+
// op has pushed one or more values onto the stack.
|
| 252 |
+
return PopResult<Result>::call(stack);
|
| 253 |
+
} else {
|
| 254 |
+
// op returns void, boxed kernel has pushed nothing onto stack.
|
| 255 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 256 |
+
stack.empty(),
|
| 257 |
+
"Boxed kernel was expected to return no values on the stack, ",
|
| 258 |
+
"but instead returned ",
|
| 259 |
+
stack.size(),
|
| 260 |
+
" values.");
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
};
|
| 264 |
+
|
| 265 |
+
//
|
| 266 |
+
// 3. in-place ops take a single non-const Tensor reference
|
| 267 |
+
// as their first argument, and return it.
|
| 268 |
+
//
|
| 269 |
+
// Note: all signatures matching this pattern are assumed to be for such ops.
|
| 270 |
+
// Because of this, the generated BoxedKernelWrapper specializations simply
|
| 271 |
+
// return the in-place argument.
|
| 272 |
+
//
|
| 273 |
+
|
| 274 |
+
template <class... OtherArgs>
|
| 275 |
+
struct BoxedKernelWrapper<
|
| 276 |
+
at::Tensor&(at::Tensor&, OtherArgs...),
|
| 277 |
+
std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
|
| 278 |
+
static at::Tensor& call(
|
| 279 |
+
const BoxedKernel& boxed_kernel_func,
|
| 280 |
+
const OperatorHandle& opHandle,
|
| 281 |
+
DispatchKeySet dispatchKeySet,
|
| 282 |
+
at::Tensor& outArg,
|
| 283 |
+
OtherArgs... otherArgs) {
|
| 284 |
+
torch::jit::Stack stack = boxArgs<at::Tensor&, OtherArgs...>(
|
| 285 |
+
outArg, std::forward<OtherArgs>(otherArgs)...);
|
| 286 |
+
boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
|
| 287 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 288 |
+
stack.size() == 1,
|
| 289 |
+
"Boxed kernel was expected to return a single value on the stack, ",
|
| 290 |
+
"but instead returned ",
|
| 291 |
+
stack.size(),
|
| 292 |
+
" values.");
|
| 293 |
+
|
| 294 |
+
return outArg;
|
| 295 |
+
}
|
| 296 |
+
};
|
| 297 |
+
|
| 298 |
+
//
|
| 299 |
+
// 3.5. In-progress migration to make in-place ops take and return
|
| 300 |
+
// const references instead.
|
| 301 |
+
template <class... OtherArgs>
|
| 302 |
+
struct BoxedKernelWrapper<
|
| 303 |
+
const at::Tensor&(const at::Tensor&, OtherArgs...),
|
| 304 |
+
std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
|
| 305 |
+
static const at::Tensor& call(
|
| 306 |
+
const BoxedKernel& boxed_kernel_func,
|
| 307 |
+
const OperatorHandle& opHandle,
|
| 308 |
+
DispatchKeySet dispatchKeySet,
|
| 309 |
+
const at::Tensor& outArg,
|
| 310 |
+
OtherArgs... otherArgs) {
|
| 311 |
+
torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
|
| 312 |
+
boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
|
| 313 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 314 |
+
stack.size() == 1,
|
| 315 |
+
"Boxed kernel was expected to return a single value on the stack, ",
|
| 316 |
+
"but instead returned ",
|
| 317 |
+
stack.size(),
|
| 318 |
+
" values.");
|
| 319 |
+
|
| 320 |
+
return outArg;
|
| 321 |
+
}
|
| 322 |
+
};
|
| 323 |
+
|
| 324 |
+
//
|
| 325 |
+
// 4. out of place ops that take a single non-const Tensor reference as their
|
| 326 |
+
// final argument, and also return it.
|
| 327 |
+
//
|
| 328 |
+
// Note: all signatures matching this pattern are assumed to be for such ops.
|
| 329 |
+
// This assumption permits the generated BoxedKernelWrapper specializations to
|
| 330 |
+
// simply return out arguments.
|
| 331 |
+
//
|
| 332 |
+
template <class FirstArg, class... RestArgs>
|
| 333 |
+
struct BoxedKernelWrapper<
|
| 334 |
+
at::Tensor&(FirstArg, RestArgs...),
|
| 335 |
+
std::enable_if_t<
|
| 336 |
+
can_box_all<FirstArg, RestArgs...>::value
|
| 337 |
+
// this skips over in-place kernels with a non-const Tensor
|
| 338 |
+
// arg at the front, so those can unambiguously trigger the
|
| 339 |
+
// preceding specialization.
|
| 340 |
+
&& !is_mutable_tensor_ref<FirstArg>::value,
|
| 341 |
+
void>> {
|
| 342 |
+
static at::Tensor& call(
|
| 343 |
+
const BoxedKernel& boxed_kernel_func,
|
| 344 |
+
const OperatorHandle& opHandle,
|
| 345 |
+
DispatchKeySet dispatchKeySet,
|
| 346 |
+
FirstArg firstArg,
|
| 347 |
+
RestArgs... restArgs) {
|
| 348 |
+
torch::jit::Stack stack = boxArgs<FirstArg, RestArgs...>(
|
| 349 |
+
std::forward<FirstArg>(firstArg), std::forward<RestArgs>(restArgs)...);
|
| 350 |
+
boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
|
| 351 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 352 |
+
stack.size() == 1,
|
| 353 |
+
"Boxed kernel was expected to return a single value on the stack, ",
|
| 354 |
+
"but instead returned ",
|
| 355 |
+
stack.size(),
|
| 356 |
+
" values.");
|
| 357 |
+
|
| 358 |
+
// reusing restArgs after it has been forwarded here is ok because we know
|
| 359 |
+
// that the last element is of type `Tensor&`.
|
| 360 |
+
return std::get<sizeof...(RestArgs) - 1>(
|
| 361 |
+
std::tuple<RestArgs...>{restArgs...});
|
| 362 |
+
}
|
| 363 |
+
};
|
| 364 |
+
|
| 365 |
+
//
|
| 366 |
+
// 5. out of place ops that take multiple non-const Tensor references as their
|
| 367 |
+
// final arguments, and return them in a std::tuple.
|
| 368 |
+
//
|
| 369 |
+
// Note: all signatures matching this pattern are assumed to be for such ops.
|
| 370 |
+
// This assumption permits the generated BoxedKernelWrapper specializations to
|
| 371 |
+
// simply return the out arguments.
|
| 372 |
+
//
|
| 373 |
+
template <class Result, class... Args>
|
| 374 |
+
struct BoxedKernelWrapper<
|
| 375 |
+
Result(Args...),
|
| 376 |
+
std::enable_if_t<
|
| 377 |
+
can_box_all<Args...>::value &&
|
| 378 |
+
is_tuple_of_mutable_tensor_refs<Result>::value,
|
| 379 |
+
void>> {
|
| 380 |
+
static Result call(
|
| 381 |
+
const BoxedKernel& boxed_kernel_func,
|
| 382 |
+
const OperatorHandle& opHandle,
|
| 383 |
+
DispatchKeySet dispatchKeySet,
|
| 384 |
+
Args... args) {
|
| 385 |
+
using ArgTuple = std::tuple<Args...>;
|
| 386 |
+
constexpr int RetCount = std::tuple_size<Result>();
|
| 387 |
+
|
| 388 |
+
torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
|
| 389 |
+
boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
|
| 390 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 391 |
+
stack.size() == RetCount,
|
| 392 |
+
"Boxed kernel was expected to return ",
|
| 393 |
+
RetCount,
|
| 394 |
+
" values on the stack, ",
|
| 395 |
+
"but instead returned ",
|
| 396 |
+
stack.size(),
|
| 397 |
+
" values.");
|
| 398 |
+
|
| 399 |
+
// reusing args after it has been forwarded here is ok because we know
|
| 400 |
+
// that the last RetCount elements are of type `Tensor&`.
|
| 401 |
+
auto result = guts::tuple_take<ArgTuple, -RetCount>(
|
| 402 |
+
ArgTuple{std::forward<Args>(args)...});
|
| 403 |
+
static_assert(
|
| 404 |
+
std::is_same_v<Result, decltype(result)>,
|
| 405 |
+
"The parameter list of an op returning a tuple of Tensor references "
|
| 406 |
+
"must end with an equal number of Tensor reference parameters.");
|
| 407 |
+
return result;
|
| 408 |
+
}
|
| 409 |
+
};
|
| 410 |
+
|
| 411 |
+
} // namespace c10::impl
|
| 412 |
+
|
| 413 |
+
#else
|
| 414 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 415 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
ADDED
|
@@ -0,0 +1,790 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/IListRef.h>
|
| 5 |
+
#include <ATen/core/boxing/OperatorKernel.h>
|
| 6 |
+
#include <ATen/core/ivalue.h>
|
| 7 |
+
#include <ATen/core/stack.h>
|
| 8 |
+
#include <c10/util/Metaprogramming.h>
|
| 9 |
+
#include <c10/util/TypeList.h>
|
| 10 |
+
#include <c10/util/intrusive_ptr.h>
|
| 11 |
+
|
| 12 |
+
#include <utility>
|
| 13 |
+
|
| 14 |
+
namespace c10 {
|
| 15 |
+
|
| 16 |
+
using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
|
| 17 |
+
// to the c10 namespace.
|
| 18 |
+
class OperatorHandle;
|
| 19 |
+
|
| 20 |
+
/*
|
| 21 |
+
* [Note: Argument forwarding in the dispatcher]
|
| 22 |
+
*
|
| 23 |
+
* The dispatcher uses a somewhat unusual way to forward arguments through
|
| 24 |
+
* several layers of wrapper functions. This can be confusing because an
|
| 25 |
+
* experienced C++ programmer would look at this and think "oh this is supposed
|
| 26 |
+
* to be forwarding a universal reference but the && is missing. This is a
|
| 27 |
+
* bug.". It is not a bug. The common way in C++ to forward arguments is to use
|
| 28 |
+
* universal references:
|
| 29 |
+
*
|
| 30 |
+
* > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
|
| 31 |
+
*
|
| 32 |
+
* but that relies on inferring the correct reference type (i.e. value vs & vs
|
| 33 |
+
* &&) from the argument. In our case, we cannot rely on the argument as
|
| 34 |
+
* supplied by the caller, because that could infer a different reference type
|
| 35 |
+
* than was used in the kernel function. The correct reference type is dictated
|
| 36 |
+
* by the kernel signature and must be identical since we cast function pointers
|
| 37 |
+
* through void* pointers and mismatches would be UB. So we need a forwarding
|
| 38 |
+
* pattern that determines the reference type to use by looking at the
|
| 39 |
+
* explicitly supplied operator signature, not by looking at the argument we're
|
| 40 |
+
* calling it with.
|
| 41 |
+
*
|
| 42 |
+
* What does std::forward do, exactly?
|
| 43 |
+
* ------------------------------------
|
| 44 |
+
* std::forward<T>(t) is a way to cast t to the reference type supplied in T.
|
| 45 |
+
* Let's assume decay_t<T> == U and T is either U or some reference of U.
|
| 46 |
+
* - std::forward<U&>(t) will return U&, no matter what kind of reference t is.
|
| 47 |
+
* - std::forward<U&&>(t) will return U&&, no matter what kind of reference t
|
| 48 |
+
* is.
|
| 49 |
+
* - std::forward<U>(t) will return U&& (not U!), no matter what kind of
|
| 50 |
+
* reference t is.
|
| 51 |
+
*
|
| 52 |
+
* For universal references, that means that in the following function
|
| 53 |
+
* > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
|
| 54 |
+
*
|
| 55 |
+
* - when called with arg being an rvalue reference or non-reference value, T
|
| 56 |
+
* gets inferred to be a non-reference U, and std::forward<T>(t) will return
|
| 57 |
+
* U&&, correctly moving the argument.
|
| 58 |
+
* - when called with arg behind an lvalue reference, T gets inferred to be U&
|
| 59 |
+
* because that's the only way to match the signature (in C++, a type that is
|
| 60 |
+
* (T&)&& will collapse to T&). That means std::forward<T>(t) will return U& and
|
| 61 |
+
* the value will not be moved but passed on as an lvalue reference.
|
| 62 |
+
*
|
| 63 |
+
* How do we use that?
|
| 64 |
+
* ------------------------------------
|
| 65 |
+
* But std::forward can also be used outside of the common "universal
|
| 66 |
+
* forwarding" pattern to change reference types. So instead of following the
|
| 67 |
+
* common C++ pattern, we notice what std::forward<T>() actually does, and that
|
| 68 |
+
* is it takes a value and changes its reference to the type of reference passed
|
| 69 |
+
* in as T. If we don't infer T but explicitly specify it, we can use this to
|
| 70 |
+
* forward based on an explicitly specified reference type instead of the
|
| 71 |
+
* inferred argument type.
|
| 72 |
+
*
|
| 73 |
+
* This is why many of the dispatcher functions look like
|
| 74 |
+
* > template<class T> func(T t) { func2<T>(std::forward<T>(t)); }
|
| 75 |
+
* instead of the common
|
| 76 |
+
* > template<class T> func(T&& t) { func2(std::forward<T>(t)); }
|
| 77 |
+
*
|
| 78 |
+
* and are expected to be called by explicitly specifying the template
|
| 79 |
+
* parameters in a way that matches the expected operator signature at each call
|
| 80 |
+
* site.
|
| 81 |
+
*/
|
| 82 |
+
|
| 83 |
+
namespace impl {
|
| 84 |
+
// supported_primitive_arg_types defines which primitive types we allow in
|
| 85 |
+
// kernel functions as arguments or returns.
|
| 86 |
+
// Additionally, we support lists, dicts and optionals containing these types.
|
| 87 |
+
using supported_primitive_arg_types = guts::typelist::typelist<
|
| 88 |
+
int64_t,
|
| 89 |
+
double,
|
| 90 |
+
bool,
|
| 91 |
+
std::string_view,
|
| 92 |
+
at::Tensor,
|
| 93 |
+
at::Scalar,
|
| 94 |
+
c10::QScheme,
|
| 95 |
+
c10::ScalarType,
|
| 96 |
+
c10::Device,
|
| 97 |
+
c10::DeviceIndex,
|
| 98 |
+
c10::Layout,
|
| 99 |
+
c10::MemoryFormat,
|
| 100 |
+
at::Dimname>;
|
| 101 |
+
|
| 102 |
+
// We have an unboxed functor in hand that takes C++ arguments, and
|
| 103 |
+
// we're building a boxed functor wrapper for it that takes IValues.
|
| 104 |
+
// So "outside" is boxed and "inside" is unboxed.
|
| 105 |
+
//
|
| 106 |
+
// So a valid input type is one that our boxed functor wrapper can
|
| 107 |
+
// unbox from an IValue into a C++ value.
|
| 108 |
+
//
|
| 109 |
+
// Whereas a valid output type is one that our wrapper can receive
|
| 110 |
+
// as a C++ value from the unboxed functor, and box into an IValue.
|
| 111 |
+
|
| 112 |
+
//
|
| 113 |
+
// assert_is_valid_input_type
|
| 114 |
+
// checks that T can be unboxed from an IValue into a C++ value.
|
| 115 |
+
//
|
| 116 |
+
|
| 117 |
+
template <class T, bool AllowDeprecatedTypes, class Enable = void>
|
| 118 |
+
struct assert_is_valid_input_type {
|
| 119 |
+
assert_is_valid_input_type() {
|
| 120 |
+
if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
|
| 121 |
+
value) {
|
| 122 |
+
/* everything is ok, this is a primitive type */
|
| 123 |
+
} else {
|
| 124 |
+
/* otherwise this must be an instance of a valid custom class, since it
|
| 125 |
+
can only have been created via IValue(x), which ensures this. */
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
};
|
| 129 |
+
|
| 130 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 131 |
+
struct assert_is_valid_input_type<std::optional<T>, AllowDeprecatedTypes>
|
| 132 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {};
|
| 133 |
+
|
| 134 |
+
template <bool AllowDeprecatedTypes, class... Args>
|
| 135 |
+
struct TypeCheckHelper;
|
| 136 |
+
|
| 137 |
+
template <bool AllowDeprecatedTypes>
|
| 138 |
+
struct TypeCheckHelper<AllowDeprecatedTypes> {};
|
| 139 |
+
|
| 140 |
+
template <bool AllowDeprecatedTypes, class Head, class... Rest>
|
| 141 |
+
struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...>
|
| 142 |
+
: TypeCheckHelper<AllowDeprecatedTypes, Rest...> {
|
| 143 |
+
assert_is_valid_input_type<Head, AllowDeprecatedTypes> check;
|
| 144 |
+
};
|
| 145 |
+
|
| 146 |
+
template <class... Contained, bool AllowDeprecatedTypes>
|
| 147 |
+
struct assert_is_valid_input_type<
|
| 148 |
+
std::tuple<Contained...>,
|
| 149 |
+
AllowDeprecatedTypes>
|
| 150 |
+
: TypeCheckHelper<AllowDeprecatedTypes, Contained...> {};
|
| 151 |
+
|
| 152 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 153 |
+
struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes>
|
| 154 |
+
: assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
|
| 155 |
+
static_assert(
|
| 156 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 157 |
+
"You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 158 |
+
};
|
| 159 |
+
|
| 160 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 161 |
+
struct assert_is_valid_input_type<
|
| 162 |
+
std::unordered_map<Key, Value>,
|
| 163 |
+
AllowDeprecatedTypes>
|
| 164 |
+
: assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
|
| 165 |
+
static_assert(
|
| 166 |
+
AllowDeprecatedTypes,
|
| 167 |
+
"You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
|
| 168 |
+
static_assert(
|
| 169 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 170 |
+
"You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 171 |
+
};
|
| 172 |
+
|
| 173 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 174 |
+
struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes>
|
| 175 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 176 |
+
static_assert(
|
| 177 |
+
!std::is_same_v<T, at::Scalar>,
|
| 178 |
+
"You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 179 |
+
};
|
| 180 |
+
|
| 181 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 182 |
+
struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
|
| 183 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 184 |
+
static_assert(
|
| 185 |
+
!std::is_same_v<T, at::Scalar>,
|
| 186 |
+
"You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 187 |
+
};
|
| 188 |
+
|
| 189 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 190 |
+
struct assert_is_valid_input_type<
|
| 191 |
+
c10::OptionalArrayRef<T>,
|
| 192 |
+
AllowDeprecatedTypes>
|
| 193 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 194 |
+
static_assert(
|
| 195 |
+
!std::is_same_v<T, at::Scalar>,
|
| 196 |
+
"You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 197 |
+
};
|
| 198 |
+
|
| 199 |
+
template <class T, size_t N, bool AllowDeprecatedTypes>
|
| 200 |
+
struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
|
| 201 |
+
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
| 202 |
+
static_assert(
|
| 203 |
+
!std::is_same_v<T, at::Scalar>,
|
| 204 |
+
"You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
|
| 205 |
+
};
|
| 206 |
+
|
| 207 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 208 |
+
struct assert_is_valid_input_type<
|
| 209 |
+
T,
|
| 210 |
+
AllowDeprecatedTypes,
|
| 211 |
+
std::enable_if_t<std::is_same_v<float, T>>> {
|
| 212 |
+
// There is no reason to support float when we have double. Keep the API lean.
|
| 213 |
+
static_assert(
|
| 214 |
+
guts::false_t<T>::value,
|
| 215 |
+
"You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
|
| 216 |
+
};
|
| 217 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 218 |
+
struct assert_is_valid_input_type<
|
| 219 |
+
T,
|
| 220 |
+
AllowDeprecatedTypes,
|
| 221 |
+
std::enable_if_t<std::is_same_v<const char*, T>>> {
|
| 222 |
+
static_assert(
|
| 223 |
+
guts::false_t<T>::value,
|
| 224 |
+
"You tried to register a kernel with an unsupported input type: const char*. Please use std::string_view instead.");
|
| 225 |
+
};
|
| 226 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 227 |
+
struct assert_is_valid_input_type<
|
| 228 |
+
T,
|
| 229 |
+
AllowDeprecatedTypes,
|
| 230 |
+
std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
|
| 231 |
+
static_assert(
|
| 232 |
+
guts::false_t<T>::value,
|
| 233 |
+
"You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead.");
|
| 234 |
+
};
|
| 235 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 236 |
+
struct assert_is_valid_input_type<
|
| 237 |
+
T,
|
| 238 |
+
AllowDeprecatedTypes,
|
| 239 |
+
std::enable_if_t<
|
| 240 |
+
std::is_integral_v<T> &&
|
| 241 |
+
!guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
|
| 242 |
+
static_assert(
|
| 243 |
+
guts::false_t<T>::value,
|
| 244 |
+
"You tried to register a kernel with an unsupported integral input type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
|
| 245 |
+
};
|
| 246 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 247 |
+
struct assert_is_valid_input_type<
|
| 248 |
+
T,
|
| 249 |
+
AllowDeprecatedTypes,
|
| 250 |
+
std::enable_if_t<std::is_same_v<const c10::SymInt&, T>>> {
|
| 251 |
+
static_assert(
|
| 252 |
+
guts::false_t<T>::value,
|
| 253 |
+
"You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
|
| 254 |
+
};
|
| 255 |
+
|
| 256 |
+
// TODO: it probably would be good to tighten this up quite a bit more with
|
| 257 |
+
// an explicit list for everything
|
| 258 |
+
|
| 259 |
+
//
|
| 260 |
+
// assert_is_valid_output_type
|
| 261 |
+
//
|
| 262 |
+
|
| 263 |
+
template <class T, bool AllowDeprecatedTypes, class Enable = void>
|
| 264 |
+
struct assert_is_valid_output_type {
|
| 265 |
+
assert_is_valid_output_type() {
|
| 266 |
+
if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
|
| 267 |
+
value) {
|
| 268 |
+
/* everything is ok, this is a primitive type */
|
| 269 |
+
} else {
|
| 270 |
+
/* otherwise T is verified to be a registered custom class in the IValue
|
| 271 |
+
constructor, so no benefit in double-checking here */
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
};
|
| 275 |
+
|
| 276 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 277 |
+
struct assert_is_valid_output_type<std::optional<T>, AllowDeprecatedTypes>
|
| 278 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
|
| 279 |
+
|
| 280 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 281 |
+
struct assert_is_valid_output_type<
|
| 282 |
+
c10::OptionalArrayRef<T>,
|
| 283 |
+
AllowDeprecatedTypes>
|
| 284 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
|
| 285 |
+
|
| 286 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 287 |
+
struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes>
|
| 288 |
+
: assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
|
| 289 |
+
static_assert(
|
| 290 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 291 |
+
"You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 292 |
+
static_assert(
|
| 293 |
+
!std::is_same_v<Value, at::Scalar>,
|
| 294 |
+
"You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
|
| 295 |
+
};
|
| 296 |
+
|
| 297 |
+
template <class Key, class Value, bool AllowDeprecatedTypes>
|
| 298 |
+
struct assert_is_valid_output_type<
|
| 299 |
+
std::unordered_map<Key, Value>,
|
| 300 |
+
AllowDeprecatedTypes>
|
| 301 |
+
: assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
|
| 302 |
+
static_assert(
|
| 303 |
+
AllowDeprecatedTypes,
|
| 304 |
+
"You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
|
| 305 |
+
static_assert(
|
| 306 |
+
guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
|
| 307 |
+
"You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
|
| 308 |
+
static_assert(
|
| 309 |
+
!std::is_same_v<Value, at::Scalar>,
|
| 310 |
+
"You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
|
| 311 |
+
};
|
| 312 |
+
|
| 313 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 314 |
+
struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes>
|
| 315 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
|
| 316 |
+
static_assert(
|
| 317 |
+
!std::is_same_v<T, at::Scalar>,
|
| 318 |
+
"You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 319 |
+
};
|
| 320 |
+
|
| 321 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 322 |
+
struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes>
|
| 323 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
|
| 324 |
+
static_assert(
|
| 325 |
+
!std::is_same_v<T, at::Scalar>,
|
| 326 |
+
"You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
| 327 |
+
// TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel
|
| 328 |
+
// with an unsupported output type: std::vector<T>. Please use List<T>
|
| 329 |
+
// instead.");
|
| 330 |
+
};
|
| 331 |
+
|
| 332 |
+
template <class T, size_t N, bool AllowDeprecatedTypes>
|
| 333 |
+
struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
|
| 334 |
+
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
|
| 335 |
+
static_assert(
|
| 336 |
+
!std::is_same_v<T, at::Scalar>,
|
| 337 |
+
"You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
|
| 338 |
+
};
|
| 339 |
+
|
| 340 |
+
// The following specialisations of assert_is_valid_output_type are technically
|
| 341 |
+
// not necessary since we would hit the base case and show an error message
|
| 342 |
+
// there if they didn't exist, but we can show a better error message
|
| 343 |
+
// in some common error scenarios.
|
| 344 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 345 |
+
struct assert_is_valid_output_type<
|
| 346 |
+
T,
|
| 347 |
+
AllowDeprecatedTypes,
|
| 348 |
+
std::enable_if_t<std::is_same_v<float, T>>> {
|
| 349 |
+
// There is no reason to support float when we have double. Keep the API lean.
|
| 350 |
+
static_assert(
|
| 351 |
+
guts::false_t<T>::value,
|
| 352 |
+
"You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
|
| 353 |
+
};
|
| 354 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 355 |
+
struct assert_is_valid_output_type<
|
| 356 |
+
T,
|
| 357 |
+
AllowDeprecatedTypes,
|
| 358 |
+
std::enable_if_t<std::is_same_v<const char*, T>>> {
|
| 359 |
+
static_assert(
|
| 360 |
+
guts::false_t<T>::value,
|
| 361 |
+
"You tried to register a kernel with an unsupported output type: const char*. Please use std::string_view instead.");
|
| 362 |
+
};
|
| 363 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 364 |
+
struct assert_is_valid_output_type<
|
| 365 |
+
T,
|
| 366 |
+
AllowDeprecatedTypes,
|
| 367 |
+
std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
|
| 368 |
+
static_assert(
|
| 369 |
+
guts::false_t<T>::value,
|
| 370 |
+
"You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead.");
|
| 371 |
+
};
|
| 372 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 373 |
+
struct assert_is_valid_output_type<
|
| 374 |
+
T,
|
| 375 |
+
AllowDeprecatedTypes,
|
| 376 |
+
std::enable_if_t<
|
| 377 |
+
std::is_integral_v<T> &&
|
| 378 |
+
!guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
|
| 379 |
+
static_assert(
|
| 380 |
+
guts::false_t<T>::value,
|
| 381 |
+
"You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
|
| 382 |
+
};
|
| 383 |
+
|
| 384 |
+
// ivalue_to_arg
|
| 385 |
+
|
| 386 |
+
template <class T>
|
| 387 |
+
struct decay_if_not_tensor final {
|
| 388 |
+
using type = std::decay_t<T>;
|
| 389 |
+
};
|
| 390 |
+
|
| 391 |
+
template <>
|
| 392 |
+
struct decay_if_not_tensor<at::Tensor&> final {
|
| 393 |
+
using type = at::Tensor&;
|
| 394 |
+
};
|
| 395 |
+
|
| 396 |
+
template <>
|
| 397 |
+
struct decay_if_not_tensor<const at::Tensor&> final {
|
| 398 |
+
using type = const at::Tensor&;
|
| 399 |
+
};
|
| 400 |
+
|
| 401 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 402 |
+
struct ivalue_to_arg final {
|
| 403 |
+
static decltype(auto) call(IValue& v) {
|
| 404 |
+
assert_is_valid_input_type<T, AllowDeprecatedTypes>();
|
| 405 |
+
return std::move(v).to<T>();
|
| 406 |
+
}
|
| 407 |
+
};
|
| 408 |
+
|
| 409 |
+
// The following two specializations take advantage of specialized
|
| 410 |
+
// `toTensor()` overloads on IValue to avoid copying.
|
| 411 |
+
template <bool AllowDeprecatedTypes>
|
| 412 |
+
struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
|
| 413 |
+
// We cannot use the default implementation if they asked for a
|
| 414 |
+
// `at::Tensor&` because it moves from the IValue, so it can't get
|
| 415 |
+
// an lvalue reference.
|
| 416 |
+
static at::Tensor& call(IValue& v) {
|
| 417 |
+
// Tensor& is valid, don't bother asserting
|
| 418 |
+
return v.toTensor();
|
| 419 |
+
}
|
| 420 |
+
};
|
| 421 |
+
|
| 422 |
+
template <bool AllowDeprecatedTypes>
|
| 423 |
+
struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
|
| 424 |
+
// We should not use the default implementation if they asked for
|
| 425 |
+
// a `const at::Tensor&` because it moves from the IValue and they
|
| 426 |
+
// didn't ask for that.
|
| 427 |
+
static const at::Tensor& call(IValue& v) {
|
| 428 |
+
// const Tensor& is valid, don't bother asserting
|
| 429 |
+
return v.toTensor();
|
| 430 |
+
}
|
| 431 |
+
};
|
| 432 |
+
|
| 433 |
+
template <bool AllowDeprecatedTypes>
|
| 434 |
+
struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final {
|
| 435 |
+
static List<at::Tensor> call(IValue& v) {
|
| 436 |
+
return v.toTensorList();
|
| 437 |
+
}
|
| 438 |
+
};
|
| 439 |
+
|
| 440 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 441 |
+
struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
|
| 442 |
+
// If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and
|
| 443 |
+
// pass that to the operator. std::vector<T> is implicitly convertible to
|
| 444 |
+
// ArrayRef<T>.
|
| 445 |
+
static std::vector<T> call(IValue& v) {
|
| 446 |
+
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
|
| 447 |
+
}
|
| 448 |
+
};
|
| 449 |
+
template <bool AllowDeprecatedTypes>
|
| 450 |
+
struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
|
| 451 |
+
static std::vector<c10::SymInt> call(IValue& v) {
|
| 452 |
+
if (v.isIntList()) {
|
| 453 |
+
std::vector<c10::SymInt> r;
|
| 454 |
+
auto src = v.toIntList();
|
| 455 |
+
std::transform(
|
| 456 |
+
src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
|
| 457 |
+
return c10::SymInt(i);
|
| 458 |
+
});
|
| 459 |
+
return r;
|
| 460 |
+
} else {
|
| 461 |
+
return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::
|
| 462 |
+
call(v);
|
| 463 |
+
}
|
| 464 |
+
}
|
| 465 |
+
};
|
| 466 |
+
template <bool AllowDeprecatedTypes>
|
| 467 |
+
struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes>
|
| 468 |
+
final {
|
| 469 |
+
static OptionalArray<c10::SymInt> call(IValue& v) {
|
| 470 |
+
if (v.isIntList()) {
|
| 471 |
+
std::vector<c10::SymInt> r;
|
| 472 |
+
auto src = v.toIntList();
|
| 473 |
+
std::transform(
|
| 474 |
+
src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
|
| 475 |
+
return c10::SymInt(i);
|
| 476 |
+
});
|
| 477 |
+
return OptionalArray<c10::SymInt>(std::move(r));
|
| 478 |
+
} else {
|
| 479 |
+
return std::move(v).to<OptionalArray<c10::SymInt>>();
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
};
|
| 483 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 484 |
+
struct ivalue_to_arg<std::optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
|
| 485 |
+
// If an argument is std::optional<ArrayRef<T>>, convert the IValue to an
|
| 486 |
+
// std::optional<std::vector<T>> and pass that to the operator.
|
| 487 |
+
// OptionalArray<T> is basically a std::optional<std::vector<T>> but
|
| 488 |
+
// implicitly convertible to std::optional<ArrayRef<T>>.
|
| 489 |
+
static OptionalArray<T> call(IValue& v) {
|
| 490 |
+
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
|
| 491 |
+
}
|
| 492 |
+
};
|
| 493 |
+
|
| 494 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 495 |
+
struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final {
|
| 496 |
+
// If an argument is OptionalArrayRef<T>, convert the IValue to an
|
| 497 |
+
// std::optional<std::vector<T>> and pass that to the operator.
|
| 498 |
+
// OptionalArray<T> is basically a std::optional<std::vector<T>> but
|
| 499 |
+
// implicitly convertible to OptionalArrayRef<T>
|
| 500 |
+
static OptionalArray<T> call(IValue& v) {
|
| 501 |
+
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
|
| 502 |
+
}
|
| 503 |
+
};
|
| 504 |
+
|
| 505 |
+
// return_to_ivalue
|
| 506 |
+
template <class T, bool AllowDeprecatedTypes, class Enable = void>
|
| 507 |
+
struct return_to_ivalue final {};
|
| 508 |
+
|
| 509 |
+
template <class T, bool AllowDeprecatedTypes>
|
| 510 |
+
struct return_to_ivalue<
|
| 511 |
+
T,
|
| 512 |
+
AllowDeprecatedTypes,
|
| 513 |
+
std::enable_if_t<!std::is_same_v<at::Tensor&, T>>>
|
| 514 |
+
final {
|
| 515 |
+
static IValue call(T&& v) {
|
| 516 |
+
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
|
| 517 |
+
return c10::ivalue::from(std::move(v));
|
| 518 |
+
}
|
| 519 |
+
static IValue copy(const T& v) {
|
| 520 |
+
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
|
| 521 |
+
return IValue(v);
|
| 522 |
+
}
|
| 523 |
+
};
|
| 524 |
+
|
| 525 |
+
// Special case to allow kernels to return `Tensor&`.
|
| 526 |
+
// TODO Delete this once kernels don't do that anymore
|
| 527 |
+
template <bool AllowDeprecatedTypes>
|
| 528 |
+
struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final {
|
| 529 |
+
static IValue call(at::Tensor& v) {
|
| 530 |
+
return c10::ivalue::from(v);
|
| 531 |
+
}
|
| 532 |
+
static IValue copy(at::Tensor& v) {
|
| 533 |
+
return IValue(v);
|
| 534 |
+
}
|
| 535 |
+
};
|
| 536 |
+
|
| 537 |
+
// wrap_kernel_functor_unboxed_
|
| 538 |
+
|
| 539 |
+
template <class KernelFunctor, class OpSignature>
|
| 540 |
+
struct wrap_kernel_functor_unboxed_ final {};
|
| 541 |
+
|
| 542 |
+
// This specialization is for kernels with a first argument that is NOT of type
|
| 543 |
+
// DispatchKeySet. This includes kernels with 0 arguments.
|
| 544 |
+
template <class KernelFunctor, class ReturnType, class... ParameterTypes>
|
| 545 |
+
struct wrap_kernel_functor_unboxed_<
|
| 546 |
+
KernelFunctor,
|
| 547 |
+
ReturnType(ParameterTypes...)>
|
| 548 |
+
final {
|
| 549 |
+
static_assert(
|
| 550 |
+
std::is_same_v<
|
| 551 |
+
ReturnType,
|
| 552 |
+
typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
|
| 553 |
+
"Return type mismatch");
|
| 554 |
+
static_assert(
|
| 555 |
+
std::is_same_v<
|
| 556 |
+
guts::typelist::typelist<ParameterTypes...>,
|
| 557 |
+
typename guts::infer_function_traits_t<
|
| 558 |
+
KernelFunctor>::parameter_types>,
|
| 559 |
+
"Parameter types mismatch");
|
| 560 |
+
|
| 561 |
+
// See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
|
| 562 |
+
// doesn't use &&
|
| 563 |
+
static ReturnType call(
|
| 564 |
+
OperatorKernel* functor,
|
| 565 |
+
DispatchKeySet /*unused*/,
|
| 566 |
+
ParameterTypes... args) {
|
| 567 |
+
KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
|
| 568 |
+
// Note [Plumbing Keys Through The Dispatcher 2]
|
| 569 |
+
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
| 570 |
+
// This functor explicitly takes in a dispatchKeySet and drops it on the
|
| 571 |
+
// floor- it does not forward it to the registered kernel.
|
| 572 |
+
//
|
| 573 |
+
// This is due to the calling convention within the dispatcher, which
|
| 574 |
+
// expects all registered kernels to have a first argument of type
|
| 575 |
+
// DispatchKeySet.
|
| 576 |
+
// This is not the case for pretty much all manually written kernels,
|
| 577 |
+
// however- this functor serves to separate the calling convention of the
|
| 578 |
+
// dispatcher from the calling convention of manually written kernels.
|
| 579 |
+
return (*functor_)(std::forward<ParameterTypes>(args)...);
|
| 580 |
+
}
|
| 581 |
+
};
|
| 582 |
+
|
| 583 |
+
// This specialization is for kernels with a first argument of type
|
| 584 |
+
// DispatchKeySet
|
| 585 |
+
template <class KernelFunctor, class ReturnType, class... ParameterTypes>
|
| 586 |
+
struct wrap_kernel_functor_unboxed_<
|
| 587 |
+
KernelFunctor,
|
| 588 |
+
ReturnType(DispatchKeySet, ParameterTypes...)>
|
| 589 |
+
final {
|
| 590 |
+
static_assert(
|
| 591 |
+
std::is_same_v<
|
| 592 |
+
ReturnType,
|
| 593 |
+
typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
|
| 594 |
+
"Return type mismatch");
|
| 595 |
+
static_assert(
|
| 596 |
+
std::is_same_v<
|
| 597 |
+
guts::typelist::typelist<DispatchKeySet, ParameterTypes...>,
|
| 598 |
+
typename guts::infer_function_traits_t<
|
| 599 |
+
KernelFunctor>::parameter_types>,
|
| 600 |
+
"Parameter types mismatch");
|
| 601 |
+
|
| 602 |
+
// See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
|
| 603 |
+
// doesn't use &&
|
| 604 |
+
static ReturnType call(
|
| 605 |
+
OperatorKernel* functor,
|
| 606 |
+
DispatchKeySet dispatchKeySet,
|
| 607 |
+
ParameterTypes... args) {
|
| 608 |
+
KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
|
| 609 |
+
// We're explicitly taking in a dispatchKeySet and forwarding it to the
|
| 610 |
+
// registered kernel. See Note [Plumbing Keys Through The Dispatcher 2] for
|
| 611 |
+
// details.
|
| 612 |
+
return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
|
| 613 |
+
}
|
| 614 |
+
};
|
| 615 |
+
|
| 616 |
+
template <class KernelFunctor>
|
| 617 |
+
using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<
|
| 618 |
+
KernelFunctor,
|
| 619 |
+
typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
|
| 620 |
+
|
| 621 |
+
// call_functor_with_args_from_stack
|
| 622 |
+
|
| 623 |
+
template <
|
| 624 |
+
class Functor,
|
| 625 |
+
bool AllowDeprecatedTypes,
|
| 626 |
+
size_t... ivalue_arg_indices,
|
| 627 |
+
typename... ArgTypes>
|
| 628 |
+
std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
|
| 629 |
+
call_functor_with_args_from_stack_(
|
| 630 |
+
OperatorKernel* functor,
|
| 631 |
+
DispatchKeySet dispatchKeySet,
|
| 632 |
+
Stack* stack,
|
| 633 |
+
std::index_sequence<ivalue_arg_indices...> /*unused*/,
|
| 634 |
+
guts::typelist::typelist<ArgTypes...>* /*unused*/) {
|
| 635 |
+
(void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would
|
| 636 |
+
// be unused and we have to silence the compiler warning.
|
| 637 |
+
|
| 638 |
+
// We're explicitly filtering out DispatchKeySet from the argument list.
|
| 639 |
+
// Some kernels take a DispatchKeySet as their first argument in order to
|
| 640 |
+
// plumb keys through the dispatcher. We don't want to expose the
|
| 641 |
+
// DispatchKeySet type to jit, so we don't include this argument on the stack.
|
| 642 |
+
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
| 643 |
+
return wrap_kernel_functor_unboxed<Functor>::call(
|
| 644 |
+
functor,
|
| 645 |
+
dispatchKeySet,
|
| 646 |
+
ivalue_to_arg<
|
| 647 |
+
typename decay_if_not_tensor<ArgTypes>::type,
|
| 648 |
+
AllowDeprecatedTypes>::
|
| 649 |
+
call(torch::jit::peek(
|
| 650 |
+
*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))...);
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
template <class Functor, bool AllowDeprecatedTypes>
|
| 654 |
+
std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
|
| 655 |
+
call_functor_with_args_from_stack(
|
| 656 |
+
OperatorKernel* functor,
|
| 657 |
+
DispatchKeySet dispatchKeySet,
|
| 658 |
+
Stack* stack) {
|
| 659 |
+
// We're explicitly filtering out DispatchKeySet from the argument list.
|
| 660 |
+
// Some kernels take a DispatchKeySet as their first argument in order to
|
| 661 |
+
// plumb keys through the dispatcher. We don't want to expose the
|
| 662 |
+
// DispatchKeySet type to jit, so we don't include this argument on the stack.
|
| 663 |
+
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
| 664 |
+
using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
|
| 665 |
+
Functor>::parameter_types;
|
| 666 |
+
constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value;
|
| 667 |
+
return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(
|
| 668 |
+
functor,
|
| 669 |
+
dispatchKeySet,
|
| 670 |
+
stack,
|
| 671 |
+
std::make_index_sequence<num_ivalue_args>(),
|
| 672 |
+
static_cast<ArgTypes*>(nullptr));
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
// push_outputs
|
| 676 |
+
|
| 677 |
+
template <class OutputType, bool AllowDeprecatedTypes>
|
| 678 |
+
struct push_outputs final {
|
| 679 |
+
// Contrary to [Note: Argument forwarding in the dispatcher], we use
|
| 680 |
+
// OutputType&& here to avoid one extra call to the move constructor in this
|
| 681 |
+
// case. This is still not a universal reference though because OutputType is
|
| 682 |
+
// an explicitly specified class template parameter.
|
| 683 |
+
static void call(OutputType&& output, Stack* stack) {
|
| 684 |
+
torch::jit::push(
|
| 685 |
+
*stack,
|
| 686 |
+
return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(
|
| 687 |
+
std::forward<OutputType>(output)));
|
| 688 |
+
}
|
| 689 |
+
static void copy(const OutputType& output, Stack* stack) {
|
| 690 |
+
torch::jit::push(
|
| 691 |
+
*stack,
|
| 692 |
+
return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
|
| 693 |
+
}
|
| 694 |
+
};
|
| 695 |
+
template <class... OutputTypes, bool AllowDeprecatedTypes>
|
| 696 |
+
struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
|
| 697 |
+
static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
|
| 698 |
+
call_(
|
| 699 |
+
std::move(output),
|
| 700 |
+
stack,
|
| 701 |
+
std::make_index_sequence<sizeof...(OutputTypes)>());
|
| 702 |
+
}
|
| 703 |
+
static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
|
| 704 |
+
copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
private:
|
| 708 |
+
template <size_t... indices>
|
| 709 |
+
static void call_(
|
| 710 |
+
std::tuple<OutputTypes...>&& output,
|
| 711 |
+
Stack* stack,
|
| 712 |
+
std::index_sequence<indices...> /*unused*/) {
|
| 713 |
+
torch::jit::push(
|
| 714 |
+
*stack,
|
| 715 |
+
return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(
|
| 716 |
+
std::forward<OutputTypes>(std::get<indices>(output)))...);
|
| 717 |
+
}
|
| 718 |
+
template <size_t... indices>
|
| 719 |
+
static void copy_(
|
| 720 |
+
const std::tuple<OutputTypes...>& output,
|
| 721 |
+
Stack* stack,
|
| 722 |
+
std::index_sequence<indices...> /*unused*/) {
|
| 723 |
+
torch::jit::push(
|
| 724 |
+
*stack,
|
| 725 |
+
return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(
|
| 726 |
+
std::get<indices>(output))...);
|
| 727 |
+
}
|
| 728 |
+
};
|
| 729 |
+
template <bool AllowDeprecatedTypes>
|
| 730 |
+
struct push_outputs<void, AllowDeprecatedTypes> final {
|
| 731 |
+
static void call(int /*dummy*/, Stack* /*stack*/) {}
|
| 732 |
+
static void copy(int /*dummy*/, Stack* /*stack*/) {}
|
| 733 |
+
};
|
| 734 |
+
|
| 735 |
+
// make_boxed_from_unboxed_functor
|
| 736 |
+
|
| 737 |
+
template <class KernelFunctor, bool AllowDeprecatedTypes>
|
| 738 |
+
struct make_boxed_from_unboxed_functor final {
|
| 739 |
+
static_assert(
|
| 740 |
+
std::is_base_of_v<OperatorKernel, KernelFunctor>,
|
| 741 |
+
"Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
| 742 |
+
|
| 743 |
+
static void call(
|
| 744 |
+
OperatorKernel* functor,
|
| 745 |
+
const OperatorHandle& /*unused*/,
|
| 746 |
+
DispatchKeySet dispatchKeySet,
|
| 747 |
+
Stack* stack) {
|
| 748 |
+
using ReturnType =
|
| 749 |
+
typename guts::infer_function_traits_t<KernelFunctor>::return_type;
|
| 750 |
+
// We're explicitly filtering out DispatchKeySet from the argument list.
|
| 751 |
+
// Some kernels take a DispatchKeySet as their first argument in order to
|
| 752 |
+
// plumb keys through the dispatcher. We don't want to expose the
|
| 753 |
+
// DispatchKeySet type to jit, so we don't include this argument on the
|
| 754 |
+
// stack. See Note [Plumbing Keys Through The Dispatcher] for the
|
| 755 |
+
// background.
|
| 756 |
+
using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
|
| 757 |
+
KernelFunctor>::parameter_types;
|
| 758 |
+
constexpr bool has_outputs = !std::is_same_v<void, ReturnType>;
|
| 759 |
+
constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value;
|
| 760 |
+
if constexpr (has_outputs) {
|
| 761 |
+
// Decay ReturnType to ReturnType_ so that if a reference gets returned,
|
| 762 |
+
// we actually store it by value and don't get a dangling reference. This
|
| 763 |
+
// is only required because some kernels still return `Tensor&`. [Note:
|
| 764 |
+
// VC++ and 'std': ambiguous symbol]
|
| 765 |
+
using ReturnType_ = ::std::decay_t<ReturnType>;
|
| 766 |
+
ReturnType_ output = call_functor_with_args_from_stack<
|
| 767 |
+
KernelFunctor,
|
| 768 |
+
AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
|
| 769 |
+
torch::jit::drop(*stack, num_inputs);
|
| 770 |
+
// See note [ VC++ and 'std': ambiguous symbol]
|
| 771 |
+
push_outputs<ReturnType_, AllowDeprecatedTypes>::call(
|
| 772 |
+
::std::move(output), stack);
|
| 773 |
+
} else {
|
| 774 |
+
call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(
|
| 775 |
+
functor, dispatchKeySet, stack);
|
| 776 |
+
torch::jit::drop(*stack, num_inputs);
|
| 777 |
+
}
|
| 778 |
+
}
|
| 779 |
+
};
|
| 780 |
+
} // namespace impl
|
| 781 |
+
|
| 782 |
+
} // namespace c10
|
| 783 |
+
|
| 784 |
+
namespace torch {
|
| 785 |
+
using OperatorKernel = c10::OperatorKernel;
|
| 786 |
+
}
|
| 787 |
+
|
| 788 |
+
#else
|
| 789 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 790 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <gmock/gmock.h>
|
| 5 |
+
#include <gtest/gtest.h>
|
| 6 |
+
|
| 7 |
+
#include <ATen/core/Tensor.h>
|
| 8 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 9 |
+
#include <ATen/core/ivalue.h>
|
| 10 |
+
#include <c10/core/CPUAllocator.h>
|
| 11 |
+
#include <c10/util/irange.h>
|
| 12 |
+
|
| 13 |
+
template <class... Inputs>
|
| 14 |
+
inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
|
| 15 |
+
return {std::forward<Inputs>(inputs)...};
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
// Creates a one-element float tensor carrying the dispatch key set `ks`.
// Its storage is allocated with the CPU allocator and marked resizable.
// Unless `requires_grad` is true, the autograd key is removed from the
// tensor before it is returned.
inline at::Tensor dummyTensor(
    c10::DispatchKeySet ks,
    bool requires_grad = false) {
  auto scalar_type = caffe2::TypeMeta::Make<float>();
  auto* cpu_allocator = c10::GetCPUAllocator();

  int64_t element_count = 1;
  int64_t byte_count = element_count * scalar_type.itemsize();

  auto storage = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      byte_count,
      cpu_allocator->allocate(byte_count),
      cpu_allocator,
      /*resizable=*/true);

  at::Tensor tensor =
      at::detail::make_tensor<c10::TensorImpl>(storage, ks, scalar_type);

  // TODO: This simulates the ideal case where Autograd backend keys are only
  // present on a Tensor that requires grad. Currently the TensorImpl
  // constructor adds Autograd keys by default.
  if (requires_grad) {
    return tensor;
  }
  tensor.unsafeGetTensorImpl()->remove_autograd_key();
  return tensor;
}
|
| 42 |
+
|
| 43 |
+
// Convenience overload: wraps the single `dispatch_key` in a DispatchKeySet
// and delegates to the DispatchKeySet overload of dummyTensor.
inline at::Tensor dummyTensor(
    c10::DispatchKey dispatch_key,
    bool requires_grad = false) {
  c10::DispatchKeySet single_key_set(dispatch_key);
  return dummyTensor(single_key_set, requires_grad);
}
|
| 48 |
+
|
| 49 |
+
template <class... Args>
|
| 50 |
+
inline std::vector<c10::IValue> callOp(
|
| 51 |
+
const c10::OperatorHandle& op,
|
| 52 |
+
Args... args) {
|
| 53 |
+
auto stack = makeStack(std::forward<Args>(args)...);
|
| 54 |
+
op.callBoxed(&stack);
|
| 55 |
+
return stack;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
template <class Result, class... Args>
|
| 59 |
+
inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
|
| 60 |
+
return op.typed<Result(Args...)>().call(std::forward<Args>(args)...);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
template <class Result, class... Args>
|
| 64 |
+
inline Result callOpUnboxedWithDispatchKey(
|
| 65 |
+
const c10::OperatorHandle& op,
|
| 66 |
+
c10::DispatchKey dispatchKey,
|
| 67 |
+
Args... args) {
|
| 68 |
+
return op.typed<Result(Args...)>().callWithDispatchKey(
|
| 69 |
+
dispatchKey, std::forward<Args>(args)...);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <class Result, class... Args>
|
| 73 |
+
inline Result callOpUnboxedWithPrecomputedDispatchKeySet(
|
| 74 |
+
const c10::OperatorHandle& op,
|
| 75 |
+
c10::DispatchKeySet ks,
|
| 76 |
+
Args... args) {
|
| 77 |
+
return op.typed<Result(Args...)>().redispatch(
|
| 78 |
+
ks, std::forward<Args>(args)...);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
inline void expectDoesntFindKernel(
|
| 82 |
+
const char* op_name,
|
| 83 |
+
c10::DispatchKey dispatch_key) {
|
| 84 |
+
auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
|
| 85 |
+
EXPECT_ANY_THROW(callOp(*op, dummyTensor(dispatch_key), 5););
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
inline void expectDoesntFindOperator(const char* op_name) {
|
| 89 |
+
auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
|
| 90 |
+
EXPECT_FALSE(op.has_value());
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
template <class Exception, class Functor>
|
| 94 |
+
inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
|
| 95 |
+
try {
|
| 96 |
+
std::forward<Functor>(functor)();
|
| 97 |
+
} catch (const Exception& e) {
|
| 98 |
+
EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
|
| 99 |
+
return;
|
| 100 |
+
}
|
| 101 |
+
ADD_FAILURE() << "Expected to throw exception containing \""
|
| 102 |
+
<< expectMessageContains << "\" but didn't throw";
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
template <class T, size_t N>
|
| 106 |
+
void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
|
| 107 |
+
EXPECT_EQ(expected.size(), actual.size());
|
| 108 |
+
for (const auto i : c10::irange(expected.size())) {
|
| 109 |
+
EXPECT_EQ(expected[i], actual[i]);
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
template <class T>
|
| 114 |
+
void expectListEquals(c10::ArrayRef<T> expected, c10::ArrayRef<T> actual) {
|
| 115 |
+
EXPECT_EQ(expected.size(), actual.size());
|
| 116 |
+
for (const auto i : c10::irange(expected.size())) {
|
| 117 |
+
EXPECT_EQ(expected[i], actual[i]);
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
template <class T>
|
| 122 |
+
void expectListEquals(c10::ArrayRef<T> expected, c10::List<T> actual) {
|
| 123 |
+
EXPECT_EQ(expected.size(), actual.size());
|
| 124 |
+
for (const auto i : c10::irange(expected.size())) {
|
| 125 |
+
EXPECT_EQ(expected[i], actual.get(i));
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
template <class T>
|
| 130 |
+
void expectListEquals(c10::ArrayRef<T> expected, std::vector<T> actual) {
|
| 131 |
+
EXPECT_EQ(expected.size(), actual.size());
|
| 132 |
+
for (const auto i : c10::irange(expected.size())) {
|
| 133 |
+
EXPECT_EQ(expected[i], actual[i]);
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
// NB: This is not really sound, but all of the type sets constructed here
|
| 138 |
+
// are singletons so it's fine
|
| 139 |
+
static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
|
| 140 |
+
return legacyExtractDispatchKey(t.key_set());
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
#else
|
| 144 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 145 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/CppSignature.h
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/DispatchKeySet.h>
#include <c10/macros/Macros.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/Type.h>

#include <cstring>
#include <string>
#include <typeindex>
|
| 9 |
+
|
| 10 |
+
namespace c10::impl {
|
| 11 |
+
|
| 12 |
+
// A CppSignature object holds RTTI information about a C++ function signature
|
| 13 |
+
// at runtime and can compare them or get a debug-printable name.
|
| 14 |
+
class TORCH_API CppSignature final {
|
| 15 |
+
public:
|
| 16 |
+
CppSignature(const CppSignature&) = default;
|
| 17 |
+
CppSignature(CppSignature&&) noexcept = default;
|
| 18 |
+
CppSignature& operator=(const CppSignature&) = default;
|
| 19 |
+
CppSignature& operator=(CppSignature&&) noexcept = default;
|
| 20 |
+
|
| 21 |
+
template <class FuncType>
|
| 22 |
+
static CppSignature make() {
|
| 23 |
+
// Normalize functors, lambdas, function pointers, etc. into the plain
|
| 24 |
+
// function type The first argument of the schema might be of type
|
| 25 |
+
// DispatchKeySet, in which case we remove it. We do this to guarantee that
|
| 26 |
+
// all CppSignature's for an operator will match, even if they're registered
|
| 27 |
+
// with different calling conventions.
|
| 28 |
+
// See Note [Plumbing Keys Through The Dispatcher]
|
| 29 |
+
using decayed_function_type =
|
| 30 |
+
typename c10::remove_DispatchKeySet_arg_from_func<
|
| 31 |
+
std::decay_t<FuncType>>::func_type;
|
| 32 |
+
|
| 33 |
+
return CppSignature(std::type_index(typeid(decayed_function_type)));
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
std::string name() const {
|
| 37 |
+
return c10::demangle(signature_.name());
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) {
|
| 41 |
+
if (lhs.signature_ == rhs.signature_) {
|
| 42 |
+
return true;
|
| 43 |
+
}
|
| 44 |
+
// Without RTLD_GLOBAL, the type_index comparison could yield false because
|
| 45 |
+
// they point to different instances of the RTTI data, but the types would
|
| 46 |
+
// still be the same. Let's check for that case too.
|
| 47 |
+
// Note that there still is a case where this might not work, i.e. when
|
| 48 |
+
// linking libraries of different compilers together, they might have
|
| 49 |
+
// different ways to serialize a type name. That, together with a missing
|
| 50 |
+
// RTLD_GLOBAL, would still fail this.
|
| 51 |
+
if (0 == strcmp(lhs.signature_.name(), rhs.signature_.name())) {
|
| 52 |
+
return true;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
return false;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
private:
|
| 59 |
+
explicit CppSignature(std::type_index signature)
|
| 60 |
+
: signature_(std::move(signature)) {}
|
| 61 |
+
std::type_index signature_;
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
// Inequality is defined as the negation of operator== for CppSignature.
inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) {
  const bool same_signature = (lhs == rhs);
  return !same_signature;
}
|
| 67 |
+
|
| 68 |
+
} // namespace c10::impl
|
| 69 |
+
|
| 70 |
+
#else
|
| 71 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 72 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/Variadic.h>
|
| 5 |
+
#include <ATen/core/function_schema.h>
|
| 6 |
+
#include <ATen/core/jit_type.h>
|
| 7 |
+
#include <ATen/core/stack.h>
|
| 8 |
+
#include <c10/core/DispatchKeySet.h>
|
| 9 |
+
#include <c10/util/Bitset.h>
|
| 10 |
+
#include <c10/util/irange.h>
|
| 11 |
+
#include <cstdint>
|
| 12 |
+
|
| 13 |
+
namespace c10 {
|
| 14 |
+
|
| 15 |
+
namespace impl {
|
| 16 |
+
|
| 17 |
+
// Take a DispatchKeySet for a Tensor and determine what the actual dispatch
|
| 18 |
+
// DispatchKey should be, taking into account TLS, and skipping backends which
|
| 19 |
+
// fall through.
|
| 20 |
+
//
|
| 21 |
+
// Unlike Tensor::key_set(), the value of this on a tensor can change depending
|
| 22 |
+
// on TLS.
|
| 23 |
+
//
|
| 24 |
+
// NB: If there is no valid dispatch key, this will return Undefined
|
| 25 |
+
inline DispatchKeySet computeDispatchKeySet(
|
| 26 |
+
DispatchKeySet ks,
|
| 27 |
+
// The key mask lets us eliminate (by zero entries) keys which should not
|
| 28 |
+
// be considered for dispatch. There are two cases when we use this:
|
| 29 |
+
//
|
| 30 |
+
// - If an operator's dispatch table contains a fallthrough entry, we
|
| 31 |
+
// should bypass it entirely when finding the key
|
| 32 |
+
// - If a user invokes with redispatch, the mask lets us
|
| 33 |
+
// zero out the key the user asked us to stop.
|
| 34 |
+
//
|
| 35 |
+
// These excluded backends are NOT tracked in the TLS, but must be applied
|
| 36 |
+
// AFTER TLS (since the backend may have been introduced for consideration
|
| 37 |
+
// by the included TLS), which is why you have to pass them in to this
|
| 38 |
+
// function (as opposed to just applying it to the input 'ks').
|
| 39 |
+
DispatchKeySet key_mask) {
|
| 40 |
+
c10::impl::LocalDispatchKeySet local =
|
| 41 |
+
c10::impl::tls_local_dispatch_key_set();
|
| 42 |
+
// TODO: It's a bit irritating that we have to do logical ORs here, it would
|
| 43 |
+
// be nice to only do one. Can always_included be folded into the TLS? Well,
|
| 44 |
+
// it's a bit troublesome, because fastpath TLS access requires the type of
|
| 45 |
+
// the TLS in question to be zero-initialized, so you don't actually win
|
| 46 |
+
// anything in that case.
|
| 47 |
+
return (((ks | local.included_) - local.excluded_) & key_mask);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
} // namespace impl
|
| 51 |
+
|
| 52 |
+
namespace detail {
|
| 53 |
+
// A small gadget to extract the DispatchKeySet from types which are known
|
| 54 |
+
// to have it. Used to extract dispatch keys from unboxed calls.
|
| 55 |
+
struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
|
| 56 |
+
DispatchKeySet ts;
|
| 57 |
+
void operator()(const at::Tensor& x) {
|
| 58 |
+
ts = ts | x.key_set();
|
| 59 |
+
}
|
| 60 |
+
void operator()(const std::optional<at::Tensor>& x) {
|
| 61 |
+
if (x.has_value()) {
|
| 62 |
+
ts = ts | x->key_set();
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
void operator()(at::ArrayRef<at::Tensor> xs) {
|
| 66 |
+
for (const auto& x : xs) {
|
| 67 |
+
ts = ts | x.key_set();
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
// Tensor?[] translates to this case.
|
| 71 |
+
void operator()(const c10::List<std::optional<at::Tensor>>& xs) {
|
| 72 |
+
for (std::optional<at::Tensor> x : xs) {
|
| 73 |
+
if (x.has_value()) {
|
| 74 |
+
ts = ts | x.value().key_set();
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
// Structured Tensor[] translates to this case
|
| 79 |
+
void operator()(const at::ITensorListRef& xs) {
|
| 80 |
+
for (const auto& x : xs) {
|
| 81 |
+
ts = ts | x.key_set();
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
[[noreturn]] void operator()(
|
| 85 |
+
at::ArrayRef<std::optional<at::Tensor>> /*unused*/) {
|
| 86 |
+
// Just checking that the handling of Tensor?[] didn't change.
|
| 87 |
+
TORCH_INTERNAL_ASSERT(false);
|
| 88 |
+
}
|
| 89 |
+
void operator()(const at::Generator& gen) {
|
| 90 |
+
if (gen.defined()) {
|
| 91 |
+
ts = ts | gen.key_set();
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
void operator()(const std::optional<at::Generator>& gen) {
|
| 95 |
+
if (gen.has_value() && gen->defined()) {
|
| 96 |
+
ts = ts | gen->key_set();
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
template <typename T>
|
| 100 |
+
void operator()(const T& /*unused*/) {
|
| 101 |
+
// do nothing
|
| 102 |
+
}
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
// NB: take by const reference (Don't do universal forwarding here! You
|
| 106 |
+
// don't want to move into this function!)
|
| 107 |
+
template <typename... Args>
|
| 108 |
+
DispatchKeySet multi_dispatch_key_set(const Args&... args) {
|
| 109 |
+
return MultiDispatchKeySet().apply(args...).ts;
|
| 110 |
+
}
|
| 111 |
+
} // namespace detail
|
| 112 |
+
|
| 113 |
+
/**
|
| 114 |
+
* An instance of DispatchKeyExtractor knows how to get a dispatch key given
|
| 115 |
+
* a list of arguments for an operator call.
|
| 116 |
+
*
|
| 117 |
+
* The instance is specific for a certain operator as:
|
| 118 |
+
* - In boxed dispatch, different operators have different ways to extract
|
| 119 |
+
* the dispatch key (e.g. different numbers of arguments), and we precompute
|
| 120 |
+
* the stack locations we should look at; and
|
| 121 |
+
* - In all dispatch, some backends should be excluded from dispatch because
|
| 122 |
+
* they have been registered as fallthrough. The set of excluded backends
|
| 123 |
+
* varies from operator, as some operators may have overridden the
|
| 124 |
+
* fallthrough with custom behavior.
|
| 125 |
+
*
|
| 126 |
+
* Note - this should maintain identical impl to the py dispatcher key
|
| 127 |
+
* extraction logic at pytorch/torch/dispatcher.py
|
| 128 |
+
*/
|
| 129 |
+
struct TORCH_API DispatchKeyExtractor final {
|
| 130 |
+
public:
|
| 131 |
+
static DispatchKeyExtractor make(const FunctionSchema& schema) {
|
| 132 |
+
return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
static DispatchKeyExtractor makeUninitialized() {
|
| 136 |
+
return DispatchKeyExtractor(c10::utils::bitset());
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
void registerSchema(const FunctionSchema& schema) {
|
| 140 |
+
TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
|
| 141 |
+
dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
|
| 142 |
+
}
|
| 143 |
+
void deregisterSchema() {
|
| 144 |
+
dispatch_arg_indices_reverse_ = c10::utils::bitset();
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
|
| 148 |
+
DispatchKeySet ks;
|
| 149 |
+
dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t
|
| 150 |
+
reverse_arg_index) {
|
| 151 |
+
const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
|
| 152 |
+
if (C10_LIKELY(ivalue.isTensor())) {
|
| 153 |
+
// NB: Take care not to introduce a refcount bump (there's
|
| 154 |
+
// no safe toTensorRef method, alas)
|
| 155 |
+
ks = ks | ivalue.unsafeToTensorImpl()->key_set();
|
| 156 |
+
} else if (C10_UNLIKELY(ivalue.isTensorList())) {
|
| 157 |
+
// NB: use toListRef as it doesn't induce refcount bumps
|
| 158 |
+
// (toTensorListRef is not a thing)
|
| 159 |
+
for (const auto& nv : ivalue.toListRef()) {
|
| 160 |
+
auto* tensor = nv.unsafeToTensorImpl();
|
| 161 |
+
ks = ks | tensor->key_set();
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
// Tensor?[] translates to a c10::List<IValue> so we need to peek inside
|
| 165 |
+
else if (C10_UNLIKELY(ivalue.isList())) {
|
| 166 |
+
for (const auto& elt : ivalue.toListRef()) {
|
| 167 |
+
if (elt.isTensor()) {
|
| 168 |
+
ks = ks | elt.toTensor().key_set();
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
});
|
| 173 |
+
// Keys that are fallthrough should be skipped
|
| 174 |
+
if (requiresBitsetPerBackend_) {
|
| 175 |
+
c10::impl::LocalDispatchKeySet tls =
|
| 176 |
+
c10::impl::tls_local_dispatch_key_set();
|
| 177 |
+
auto backend_idx =
|
| 178 |
+
((ks | tls.included_) - tls.excluded_).getBackendIndex();
|
| 179 |
+
return impl::computeDispatchKeySet(
|
| 180 |
+
ks, nonFallthroughKeysPerBackend_[backend_idx]);
|
| 181 |
+
} else {
|
| 182 |
+
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
template <class... Args>
|
| 187 |
+
DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
|
| 188 |
+
auto ks = detail::multi_dispatch_key_set(args...);
|
| 189 |
+
// Keys that are fallthrough should be skipped
|
| 190 |
+
if (requiresBitsetPerBackend_) {
|
| 191 |
+
c10::impl::LocalDispatchKeySet tls =
|
| 192 |
+
c10::impl::tls_local_dispatch_key_set();
|
| 193 |
+
auto backend_idx =
|
| 194 |
+
((ks | tls.included_) - tls.excluded_).getBackendIndex();
|
| 195 |
+
return impl::computeDispatchKeySet(
|
| 196 |
+
ks, nonFallthroughKeysPerBackend_[backend_idx]);
|
| 197 |
+
} else {
|
| 198 |
+
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
|
| 203 |
+
|
| 204 |
+
std::string dumpState() const;
|
| 205 |
+
void checkInvariants(const FunctionSchema& schema) const;
|
| 206 |
+
|
| 207 |
+
private:
|
| 208 |
+
static bool isDispatchType(const Type& type) {
|
| 209 |
+
// Checking isSubtypeOf on a DynamicType heap-allocates a
|
| 210 |
+
// DynamicType version of the argument if it's not a DynamicType
|
| 211 |
+
// already, and this has measurable overhead during startup.
|
| 212 |
+
#ifdef C10_MOBILE
|
| 213 |
+
struct CachedTypes {
|
| 214 |
+
DynamicTypePtr listOfTensors;
|
| 215 |
+
DynamicTypePtr listOfOptionalTensors;
|
| 216 |
+
DynamicTypePtr optionalOfTensor;
|
| 217 |
+
};
|
| 218 |
+
static const CachedTypes ct = {
|
| 219 |
+
DynamicType::create(*ListType::ofTensors()),
|
| 220 |
+
DynamicType::create(*ListType::ofOptionalTensors()),
|
| 221 |
+
DynamicType::create(*OptionalType::ofTensor())};
|
| 222 |
+
return type.isSubtypeOf(c10::TypeFactory::get<TensorType>()) ||
|
| 223 |
+
type.isSubtypeOf(ct.listOfTensors) ||
|
| 224 |
+
type.isSubtypeOf(ct.listOfOptionalTensors) ||
|
| 225 |
+
type.isSubtypeOf(ct.optionalOfTensor);
|
| 226 |
+
#else // C10_MOBILE
|
| 227 |
+
return type.isSubtypeOf(*TensorType::get()) ||
|
| 228 |
+
type.isSubtypeOf(*ListType::ofTensors()) ||
|
| 229 |
+
type.isSubtypeOf(*ListType::ofOptionalTensors()) ||
|
| 230 |
+
type.isSubtypeOf(*OptionalType::ofTensor());
|
| 231 |
+
#endif // C10_MOBILE
|
| 232 |
+
}
|
| 233 |
+
static c10::utils::bitset makeBitsetForDispatchArgs(
|
| 234 |
+
const FunctionSchema& schema) {
|
| 235 |
+
TORCH_CHECK(
|
| 236 |
+
schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
|
| 237 |
+
"The function schema has ",
|
| 238 |
+
schema.arguments().size(),
|
| 239 |
+
" arguments but this PyTorch build only supports ",
|
| 240 |
+
c10::utils::bitset::NUM_BITS());
|
| 241 |
+
c10::utils::bitset dispatch_arg_indices_reverse;
|
| 242 |
+
for (const auto index : c10::irange(schema.arguments().size())) {
|
| 243 |
+
if (isDispatchType(*schema.arguments()[index].type())) {
|
| 244 |
+
dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
return dispatch_arg_indices_reverse;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
|
| 251 |
+
: dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse),
|
| 252 |
+
nonFallthroughKeys_(DispatchKeySet::FULL) {
|
| 253 |
+
for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
|
| 254 |
+
nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
// this is a bitset that has ones for each argument index which has to be
|
| 259 |
+
// considered for dispatch. This avoids having to iterate over the stack
|
| 260 |
+
// to find all the tensors. The bits are stored in reverse order, i.e.
|
| 261 |
+
// dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
|
| 262 |
+
// the top of the stack (i.e. the i-th last argument of the function)
|
| 263 |
+
// is relevant for dispatch.
|
| 264 |
+
// dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just
|
| 265 |
+
// means you must do the fallthrough
|
| 266 |
+
c10::utils::bitset dispatch_arg_indices_reverse_;
|
| 267 |
+
|
| 268 |
+
// Set of functionality keys for which the operator does NOT have fallthrough
|
| 269 |
+
// kernel.
|
| 270 |
+
DispatchKeySet nonFallthroughKeys_;
|
| 271 |
+
// Set of functionality keys for which the operator does NOT have fallthrough
|
| 272 |
+
// kernel, defined PER BACKEND. This is only needed if we know that the
|
| 273 |
+
// operator has a different set of fallthroughs defined for some backends.
|
| 274 |
+
std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
|
| 275 |
+
// Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast
|
| 276 |
+
// path), or if we need to fall back to the slower path and check
|
| 277 |
+
// nonFallthroughKeysPerBackend_
|
| 278 |
+
bool requiresBitsetPerBackend_{false};
|
| 279 |
+
};
|
| 280 |
+
|
| 281 |
+
} // namespace c10
|
| 282 |
+
|
| 283 |
+
#else
|
| 284 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 285 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/SequenceNumber.h>
|
| 5 |
+
#include <ATen/core/boxing/KernelFunction.h>
|
| 6 |
+
#include <ATen/core/boxing/impl/boxing.h>
|
| 7 |
+
#include <ATen/core/dispatch/CppSignature.h>
|
| 8 |
+
#include <ATen/core/dispatch/OperatorEntry.h>
|
| 9 |
+
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
|
| 10 |
+
#include <ATen/record_function.h>
|
| 11 |
+
#include <c10/core/SafePyObject.h>
|
| 12 |
+
#include <c10/util/Exception.h>
|
| 13 |
+
#include <c10/util/LeftRight.h>
|
| 14 |
+
#include <condition_variable>
|
| 15 |
+
#include <list>
|
| 16 |
+
#include <mutex>
|
| 17 |
+
#include <type_traits>
|
| 18 |
+
|
| 19 |
+
#include <ATen/core/enum_tag.h>
|
| 20 |
+
#include <ATen/core/grad_mode.h>
|
| 21 |
+
|
| 22 |
+
#ifndef NDEBUG
|
| 23 |
+
#include <iostream>
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
namespace c10 {
|
| 27 |
+
|
| 28 |
+
TORCH_API bool show_dispatch_trace();
|
| 29 |
+
TORCH_API void dispatch_trace_nesting_incr();
|
| 30 |
+
TORCH_API void dispatch_trace_nesting_decr();
|
| 31 |
+
TORCH_API int64_t dispatch_trace_nesting_value();
|
| 32 |
+
|
| 33 |
+
struct DispatchTraceNestingGuard {
|
| 34 |
+
DispatchTraceNestingGuard() {
|
| 35 |
+
dispatch_trace_nesting_incr();
|
| 36 |
+
}
|
| 37 |
+
~DispatchTraceNestingGuard() {
|
| 38 |
+
dispatch_trace_nesting_decr();
|
| 39 |
+
}
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
class TORCH_API OperatorHandle;
|
| 43 |
+
template <class FuncType>
|
| 44 |
+
class TypedOperatorHandle;
|
| 45 |
+
|
| 46 |
+
/**
|
| 47 |
+
* Implement this interface and register your instance with the dispatcher
|
| 48 |
+
* to get notified when operators are registered or deregistered with
|
| 49 |
+
* the dispatcher.
|
| 50 |
+
*
|
| 51 |
+
* NB: registration events only occur when a 'def' occurs; we don't trigger
|
| 52 |
+
* on 'impl' or 'fallback' calls.
|
| 53 |
+
*/
|
| 54 |
+
class TORCH_API OpRegistrationListener {
|
| 55 |
+
public:
|
| 56 |
+
virtual ~OpRegistrationListener();
|
| 57 |
+
|
| 58 |
+
virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
|
| 59 |
+
virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
|
| 60 |
+
};
|
| 61 |
+
|
| 62 |
+
namespace detail {
|
| 63 |
+
class RegistrationListenerList;
|
| 64 |
+
}
|
| 65 |
+
class SchemaRegistrationHandleRAII;
|
| 66 |
+
|
| 67 |
+
/**
|
| 68 |
+
* Top-level dispatch interface for dispatching via the dynamic dispatcher.
|
| 69 |
+
* Most end users shouldn't use this directly; if you're trying to register
|
| 70 |
+
* ops look in op_registration
|
| 71 |
+
*/
|
| 72 |
+
class TORCH_API Dispatcher final {
|
| 73 |
+
private:
|
| 74 |
+
// For direct access to backend fallback information
|
| 75 |
+
friend class impl::OperatorEntry;
|
| 76 |
+
|
| 77 |
+
struct OperatorDef final {
|
| 78 |
+
explicit OperatorDef(OperatorName&& op_name) : op(std::move(op_name)) {}
|
| 79 |
+
|
| 80 |
+
impl::OperatorEntry op;
|
| 81 |
+
|
| 82 |
+
// These refer to the number of outstanding RegistrationHandleRAII
|
| 83 |
+
// for this operator. def_count reflects only def() registrations
|
| 84 |
+
// (in the new world, this should only ever be 1, but old style
|
| 85 |
+
// registrations may register the schema multiple times, which
|
| 86 |
+
// will increase this count). def_and_impl_count reflects the number
|
| 87 |
+
// of combined def() and impl() registrations. When the last def() gets
|
| 88 |
+
// unregistered, we must immediately call the Deregistered listeners, but we
|
| 89 |
+
// must not actually delete the handle as there are other outstanding RAII
|
| 90 |
+
// destructors which will try to destruct and they had better still have a
|
| 91 |
+
// working operator handle in this case
|
| 92 |
+
size_t def_count = 0;
|
| 93 |
+
size_t def_and_impl_count = 0;
|
| 94 |
+
};
|
| 95 |
+
friend class OperatorHandle;
|
| 96 |
+
template <class>
|
| 97 |
+
friend class TypedOperatorHandle;
|
| 98 |
+
|
| 99 |
+
struct Guard final {
|
| 100 |
+
Guard() : alive(true) {}
|
| 101 |
+
std::atomic<bool> alive;
|
| 102 |
+
std::mutex mutex;
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
public:
|
| 106 |
+
~Dispatcher();
|
| 107 |
+
|
| 108 |
+
// Implementation note: this class abstracts over the fact that we have
|
| 109 |
+
// per-operator dispatch tables. This could be easily adjusted to have a
|
| 110 |
+
// single global hash table.
|
| 111 |
+
static Dispatcher& realSingleton();
|
| 112 |
+
|
| 113 |
+
C10_ALWAYS_INLINE static Dispatcher& singleton() {
|
| 114 |
+
#if !defined C10_MOBILE
|
| 115 |
+
// Implemented inline so that steady-state code needn't incur
|
| 116 |
+
// function-call overhead. We can't just inline `realSingleton`
|
| 117 |
+
// because the function-local static would get duplicated across
|
| 118 |
+
// all DSOs that include & use this header, leading to multiple
|
| 119 |
+
// singleton instances.
|
| 120 |
+
static Dispatcher& s = realSingleton();
|
| 121 |
+
return s;
|
| 122 |
+
#else
|
| 123 |
+
// For C10_MOBILE, we should never inline a static function that
|
| 124 |
+
// has a static member, since the generated code calls
|
| 125 |
+
// __cxa_guard_acquire and __cxa_guard_release which help
|
| 126 |
+
// implement exactly once semantics for the initialization of the
|
| 127 |
+
// static Dispatcher& s above (for the non-mobile case). That
|
| 128 |
+
// additional code when duplicated across all operator stubs
|
| 129 |
+
// for every backend results in a lot of additional code
|
| 130 |
+
// being generated by the compiler.
|
| 131 |
+
return realSingleton();
|
| 132 |
+
#endif
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
// ------------------------------------------------------------------------
|
| 136 |
+
//
|
| 137 |
+
// Accessing operators by schema
|
| 138 |
+
//
|
| 139 |
+
// ------------------------------------------------------------------------
|
| 140 |
+
|
| 141 |
+
/**
|
| 142 |
+
* Looks for an operator schema with the given name and overload name
|
| 143 |
+
* and returns it if it is registered WITH A SCHEMA.
|
| 144 |
+
* Returns nullopt otherwise.
|
| 145 |
+
*/
|
| 146 |
+
std::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
|
| 147 |
+
|
| 148 |
+
/**
|
| 149 |
+
* Variant of findSchema that results in less code generated at the call site.
|
| 150 |
+
* It (1) takes const char* pointer rather than OperatorName (so we skip
|
| 151 |
+
* generating std::string constructor calls at the call site), and (2)
|
| 152 |
+
* it raises an exception if the operator is not found (so we skip
|
| 153 |
+
* generating exception raising code at the call site)
|
| 154 |
+
*
|
| 155 |
+
* Irritatingly, we still have to generate the handful of instructions
|
| 156 |
+
* for dealing with an exception being thrown during static initialization
|
| 157 |
+
* (e.g. __cxa_guard_abort). If we could annotate this method noexcept we
|
| 158 |
+
* could avoid this code too, but as the name of the function suggests,
|
| 159 |
+
* it does throw exceptions.
|
| 160 |
+
*/
|
| 161 |
+
OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
|
| 162 |
+
|
| 163 |
+
// Like findSchema, but also returns OperatorHandle even if there is no schema
|
| 164 |
+
std::optional<OperatorHandle> findOp(const OperatorName& operator_name);
|
| 165 |
+
|
| 166 |
+
// Returns a list of all operator names present in the operatorLookupTable_
|
| 167 |
+
const std::vector<OperatorName> getAllOpNames();
|
| 168 |
+
|
| 169 |
+
// Returns a list of all operator names present in the operatorLookupTable_
|
| 170 |
+
// for a given dispatch key
|
| 171 |
+
const std::vector<OperatorName> getAllOpNamesForDispatchKey(DispatchKey k);
|
| 172 |
+
|
| 173 |
+
// ------------------------------------------------------------------------
|
| 174 |
+
//
|
| 175 |
+
// Invoking operators
|
| 176 |
+
//
|
| 177 |
+
// ------------------------------------------------------------------------
|
| 178 |
+
|
| 179 |
+
template <class Return, class... Args>
|
| 180 |
+
Return call(const TypedOperatorHandle<Return(Args...)>& op, Args... args)
|
| 181 |
+
const;
|
| 182 |
+
|
| 183 |
+
template <class Return, class... Args>
|
| 184 |
+
static Return callWithDispatchKeySlowPath(
|
| 185 |
+
const TypedOperatorHandle<Return(Args...)>& op,
|
| 186 |
+
at::StepCallbacks& stepCallbacks,
|
| 187 |
+
DispatchKeySet dispatchKeySet,
|
| 188 |
+
const KernelFunction& kernel,
|
| 189 |
+
Args... args);
|
| 190 |
+
|
| 191 |
+
// Like call, but intended for use in a redispatch in kernels that have
|
| 192 |
+
// explicitly performed the DispatchKey update calculation. This will take
|
| 193 |
+
// the DispatchKeySet completely as is and dispatch to the kernel of the
|
| 194 |
+
// corresponding highest priority key in the set. Note that this version of
|
| 195 |
+
// redispatch treats the inputted DispatchKeySet *as is*, and does NOT mask
|
| 196 |
+
// out the highest priority key. See Note [Plumbing Keys Through The
|
| 197 |
+
// Dispatcher]
|
| 198 |
+
template <class Return, class... Args>
|
| 199 |
+
Return redispatch(
|
| 200 |
+
const TypedOperatorHandle<Return(Args...)>& op,
|
| 201 |
+
DispatchKeySet currentDispatchKeySet,
|
| 202 |
+
Args... args) const;
|
| 203 |
+
|
| 204 |
+
// Invoke an operator via the boxed calling convention using an IValue stack
|
| 205 |
+
void callBoxed(const OperatorHandle& op, Stack* stack) const;
|
| 206 |
+
void callBoxedForDispatchKey(
|
| 207 |
+
const OperatorHandle& op,
|
| 208 |
+
DispatchKey dk,
|
| 209 |
+
Stack* stack) const;
|
| 210 |
+
|
| 211 |
+
// TODO: This will only be useful if we write a backend fallback that plumbs
|
| 212 |
+
// dispatch keys (currently there are none) See Note [Plumbing Keys Through
|
| 213 |
+
// The Dispatcher]
|
| 214 |
+
void redispatchBoxed(
|
| 215 |
+
const OperatorHandle& op,
|
| 216 |
+
DispatchKeySet dispatchKeySet,
|
| 217 |
+
Stack* stack) const;
|
| 218 |
+
|
| 219 |
+
bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
|
| 220 |
+
auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
|
| 221 |
+
if (dispatch_ix < 0)
|
| 222 |
+
return false;
|
| 223 |
+
return backendFallbackKernels_[dispatch_ix].kernel.isValid();
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
// Used by torchdeploy/multipy for multiple // codespell:ignore: multipy
|
| 227 |
+
// interpreters racing.
|
| 228 |
+
void waitForDef(const FunctionSchema& schema);
|
| 229 |
+
void waitForImpl(
|
| 230 |
+
const OperatorName& op_name,
|
| 231 |
+
std::optional<DispatchKey> dispatch_key);
|
| 232 |
+
|
| 233 |
+
// ------------------------------------------------------------------------
|
| 234 |
+
//
|
| 235 |
+
// Performing registrations (NON user public; use op_registration)
|
| 236 |
+
//
|
| 237 |
+
// ------------------------------------------------------------------------
|
| 238 |
+
|
| 239 |
+
/**
|
| 240 |
+
* Register a new operator schema.
|
| 241 |
+
*
|
| 242 |
+
* If a schema with the same operator name and overload name already exists,
|
| 243 |
+
* this function will check that both schemas are exactly identical.
|
| 244 |
+
*/
|
| 245 |
+
RegistrationHandleRAII registerDef(
|
| 246 |
+
FunctionSchema schema,
|
| 247 |
+
std::string debug,
|
| 248 |
+
std::vector<at::Tag> tags = {});
|
| 249 |
+
|
| 250 |
+
/**
|
| 251 |
+
* Register a kernel to the dispatch table for an operator.
|
| 252 |
+
* If dispatch_key is nullopt, then this registers a fallback kernel.
|
| 253 |
+
*
|
| 254 |
+
* @return A RAII object that manages the lifetime of the registration.
|
| 255 |
+
* Once that object is destructed, the kernel will be deregistered.
|
| 256 |
+
*/
|
| 257 |
+
// NB: steals the inferred function schema, as we may need to hold on to
|
| 258 |
+
// it for a bit until the real schema turns up
|
| 259 |
+
RegistrationHandleRAII registerImpl(
|
| 260 |
+
OperatorName op_name,
|
| 261 |
+
std::optional<DispatchKey> dispatch_key,
|
| 262 |
+
KernelFunction kernel,
|
| 263 |
+
std::optional<impl::CppSignature> cpp_signature,
|
| 264 |
+
std::unique_ptr<FunctionSchema> inferred_function_schema,
|
| 265 |
+
std::string debug);
|
| 266 |
+
|
| 267 |
+
/**
|
| 268 |
+
* Given an operator, tells the Dispatcher that we have implemented a fake
|
| 269 |
+
* impl for this op in the given Python module. Call this a "pystub".
|
| 270 |
+
*/
|
| 271 |
+
RegistrationHandleRAII registerPythonModule(
|
| 272 |
+
const OperatorName& op_name,
|
| 273 |
+
const char* pymodule,
|
| 274 |
+
const char* context);
|
| 275 |
+
|
| 276 |
+
/**
|
| 277 |
+
* Given an operator, throws if we have a pystub.
|
| 278 |
+
*/
|
| 279 |
+
void throwIfHasPythonModule(OperatorName op_name);
|
| 280 |
+
|
| 281 |
+
std::optional<std::pair<const char*, const char*>> getPyStub(
|
| 282 |
+
OperatorName op_name);
|
| 283 |
+
|
| 284 |
+
/**
|
| 285 |
+
* Register a new operator by name.
|
| 286 |
+
*/
|
| 287 |
+
RegistrationHandleRAII registerName(OperatorName op_name);
|
| 288 |
+
|
| 289 |
+
/**
|
| 290 |
+
* Register a fallback kernel for a backend.
|
| 291 |
+
* If an operator is called but there is no concrete kernel for the dispatch
|
| 292 |
+
* key of the given operator arguments, it will check if there is such a
|
| 293 |
+
* fallback kernel for the given dispatch key and, if yes, call that one.
|
| 294 |
+
*/
|
| 295 |
+
RegistrationHandleRAII registerFallback(
|
| 296 |
+
DispatchKey dispatch_key,
|
| 297 |
+
KernelFunction kernel,
|
| 298 |
+
std::string debug);
|
| 299 |
+
|
| 300 |
+
/**
|
| 301 |
+
* Use to register whenever we had a TORCH_LIBRARY declaration in the frontend
|
| 302 |
+
* API. These invocations are only permitted once per program, so we raise
|
| 303 |
+
* an error if this is called again for the same namespace.
|
| 304 |
+
*/
|
| 305 |
+
RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
|
| 306 |
+
|
| 307 |
+
// ------------------------------------------------------------------------
|
| 308 |
+
//
|
| 309 |
+
// Listeners on registrations
|
| 310 |
+
//
|
| 311 |
+
// ------------------------------------------------------------------------
|
| 312 |
+
|
| 313 |
+
/**
|
| 314 |
+
* Add a listener that gets called whenever a new op is registered or an
|
| 315 |
+
* existing op is deregistered. Immediately after registering, this listener
|
| 316 |
+
* gets called for all previously registered ops, so it can be used to keep
|
| 317 |
+
* track of ops registered with this dispatcher.
|
| 318 |
+
*/
|
| 319 |
+
RegistrationHandleRAII addRegistrationListener(
|
| 320 |
+
std::unique_ptr<OpRegistrationListener> listener);
|
| 321 |
+
|
| 322 |
+
void checkInvariants() const;
|
| 323 |
+
|
| 324 |
+
//
|
| 325 |
+
// ------------------------------------------------------------------------
|
| 326 |
+
//
|
| 327 |
+
// Assertions
|
| 328 |
+
//
|
| 329 |
+
// ------------------------------------------------------------------------
|
| 330 |
+
|
| 331 |
+
/**
|
| 332 |
+
* For testing purposes.
|
| 333 |
+
* Returns a list of all operators that were created through calls to
|
| 334 |
+
* registerImpl(), without any corresponding calls to registerDef(). After
|
| 335 |
+
* static initialization is done this is almost certainly a bug, as the
|
| 336 |
+
* created OperatorHandle won't have any schema associated with it and users
|
| 337 |
+
* calling the op through the dispatcher won't be able to access it
|
| 338 |
+
*
|
| 339 |
+
* Note that we cannot enforce this invariant "as we go" during static
|
| 340 |
+
* initialization, due to undefined static initialization order; we have no
|
| 341 |
+
* guarantees over the order in which .def() and .impl() calls are registered
|
| 342 |
+
* in the dispatcher at static initialization time. So this function should
|
| 343 |
+
* only be called after static initialization.
|
| 344 |
+
*/
|
| 345 |
+
std::vector<OperatorHandle> findDanglingImpls() const;
|
| 346 |
+
|
| 347 |
+
/**
|
| 348 |
+
* Useful for inspecting global Dispatcher registration state.
|
| 349 |
+
* Returns the names of all operators with a kernel registered for the
|
| 350 |
+
* specified DispatchKey. If no DispatchKey is specified, it returns all
|
| 351 |
+
* registered operators.
|
| 352 |
+
*/
|
| 353 |
+
std::vector<OperatorName> getRegistrationsForDispatchKey(
|
| 354 |
+
std::optional<DispatchKey> k) const;
|
| 355 |
+
|
| 356 |
+
private:
|
| 357 |
+
Dispatcher();
|
| 358 |
+
|
| 359 |
+
static int64_t sequenceNumberForRunningRecordFunction(
|
| 360 |
+
DispatchKey dispatchKey,
|
| 361 |
+
DispatchKeySet dispatchKeySet);
|
| 362 |
+
static void runRecordFunction(
|
| 363 |
+
at::RecordFunction& guard,
|
| 364 |
+
at::RecordFunction::schema_ref_t schema_ref,
|
| 365 |
+
DispatchKey dispatchKey,
|
| 366 |
+
DispatchKeySet dispatchKeySet);
|
| 367 |
+
static void runRecordFunction(
|
| 368 |
+
at::RecordFunction& guard,
|
| 369 |
+
at::RecordFunction::schema_ref_t schema_ref,
|
| 370 |
+
DispatchKey dispatchKey,
|
| 371 |
+
DispatchKeySet dispatchKeySet,
|
| 372 |
+
c10::ArrayRef<const c10::IValue> args);
|
| 373 |
+
|
| 374 |
+
#ifdef FBCODE_CAFFE2
|
| 375 |
+
static bool profilingOperatorEvents();
|
| 376 |
+
static void fireOpStartUSDT(
|
| 377 |
+
at::RecordFunction::schema_ref_t schema_ref,
|
| 378 |
+
std::vector<void*>& argsAddresses,
|
| 379 |
+
std::vector<const char*>& argsTypes);
|
| 380 |
+
static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref);
|
| 381 |
+
#endif // FBCODE_CAFFE2
|
| 382 |
+
|
| 383 |
+
OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
|
| 384 |
+
OperatorHandle findOrRegisterName_(const OperatorName& op_name);
|
| 385 |
+
|
| 386 |
+
void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
|
| 387 |
+
void deregisterImpl_(
|
| 388 |
+
const OperatorHandle& op,
|
| 389 |
+
const OperatorName& op_name,
|
| 390 |
+
std::optional<DispatchKey> dispatch_key,
|
| 391 |
+
impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
|
| 392 |
+
void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
|
| 393 |
+
void deregisterFallback_(DispatchKey dispatchKey);
|
| 394 |
+
void deregisterLibrary_(const std::string& ns);
|
| 395 |
+
void cleanup(const OperatorHandle& op, const OperatorName& op_name);
|
| 396 |
+
void checkSchemaCompatibility(
|
| 397 |
+
const OperatorHandle& op,
|
| 398 |
+
const FunctionSchema& schema,
|
| 399 |
+
const std::string& debug);
|
| 400 |
+
|
| 401 |
+
std::list<OperatorDef> operators_;
|
| 402 |
+
#if !defined(C10_MOBILE)
|
| 403 |
+
LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>>
|
| 404 |
+
operatorLookupTable_;
|
| 405 |
+
#else
|
| 406 |
+
RWSafeLeftRightWrapper<ska::flat_hash_map<OperatorName, OperatorHandle>>
|
| 407 |
+
operatorLookupTable_;
|
| 408 |
+
#endif
|
| 409 |
+
// Map from namespace to debug string (saying, e.g., where the library was
|
| 410 |
+
// defined)
|
| 411 |
+
ska::flat_hash_map<std::string, std::string> libraries_;
|
| 412 |
+
|
| 413 |
+
std::array<impl::AnnotatedKernel, num_runtime_entries>
|
| 414 |
+
backendFallbackKernels_;
|
| 415 |
+
|
| 416 |
+
std::unique_ptr<detail::RegistrationListenerList> listeners_;
|
| 417 |
+
|
| 418 |
+
// This condition variable gets notified whenever we add a new def/impl to the
|
| 419 |
+
// dispatch table. This is primarily used by multipy/torchdeploy, when
|
| 420 |
+
// we have multiple interpreters trying to register to the dispatch table.
|
| 421 |
+
// In this situation, whenever the non-primary interpreter would have tried
|
| 422 |
+
// to register to the dispatch table, instead it will check to see if the
|
| 423 |
+
// expected registration has already been made, and if it hasn't, wait on
|
| 424 |
+
// this condition variable to see if it was just racing with the primary
|
| 425 |
+
// interpreter.
|
| 426 |
+
//
|
| 427 |
+
// We expect it to be rare for there to be any waiters on this condition
|
| 428 |
+
// variable. This is mostly just to help give better diagnostics if
|
| 429 |
+
// something goes horribly wrong
|
| 430 |
+
std::condition_variable cond_var_;
|
| 431 |
+
|
| 432 |
+
// Protect concurrent access to the dispatcher. We store this in a
|
| 433 |
+
// `shared_ptr` as we return callbacks that call back into dispatcher methods,
|
| 434 |
+
// and we need to be able to handle and guard against the event when the
|
| 435 |
+
// `Dispatcher` has been destroyed before the callbacks fire.
|
| 436 |
+
std::shared_ptr<Guard> guard_;
|
| 437 |
+
};
|
| 438 |
+
|
| 439 |
+
/**
|
| 440 |
+
* This is a handle to an operator schema registered with the dispatcher.
|
| 441 |
+
* This handle can be used to register kernels with the dispatcher or
|
| 442 |
+
* to lookup a kernel for a certain set of arguments.
|
| 443 |
+
*/
|
| 444 |
+
class TORCH_API OperatorHandle {
|
| 445 |
+
template <typename T>
|
| 446 |
+
friend struct std::hash;
|
| 447 |
+
|
| 448 |
+
public:
|
| 449 |
+
OperatorHandle(OperatorHandle&&) noexcept = default;
|
| 450 |
+
OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
|
| 451 |
+
OperatorHandle(const OperatorHandle&) = default;
|
| 452 |
+
OperatorHandle& operator=(const OperatorHandle&) = default;
|
| 453 |
+
// NOLINTNEXTLINE(performance-trivially-destructible)
|
| 454 |
+
~OperatorHandle();
|
| 455 |
+
|
| 456 |
+
const OperatorName& operator_name() const {
|
| 457 |
+
return operatorDef_->op.operator_name();
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
bool hasSchema() const {
|
| 461 |
+
return operatorDef_->op.hasSchema();
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
const FunctionSchema& schema() const {
|
| 465 |
+
return operatorDef_->op.schema();
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
const std::string& debug() const {
|
| 469 |
+
return operatorDef_->op.debug();
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
std::string dumpState() const {
|
| 473 |
+
return operatorDef_->op.dumpState();
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
bool hasKernelForDispatchKey(DispatchKey k) const {
|
| 477 |
+
return operatorDef_->op.hasKernelForDispatchKey(k);
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
bool isKernelFallthroughKernel(DispatchKey k) const {
|
| 481 |
+
return operatorDef_->op.kernelForDispatchKey(k).isFallthrough();
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
|
| 485 |
+
return operatorDef_->op.hasKernelForAnyDispatchKey(k);
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
bool hasComputedKernelForDispatchKey(DispatchKey k) const {
|
| 489 |
+
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
|
| 493 |
+
return operatorDef_->op.getComputedKernelForDispatchKey(k);
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
std::string dumpComputedTable() const {
|
| 497 |
+
return operatorDef_->op.dumpComputedTable();
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
void checkInvariants() const {
|
| 501 |
+
operatorDef_->op.checkInvariants();
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
c10::ArrayRef<at::Tag> getTags() const {
|
| 505 |
+
return operatorDef_->op.getTags();
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
|
| 509 |
+
operatorDef_->op.setReportErrorCallback_(std::move(callback));
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
bool hasTag(const at::Tag& tag) const {
|
| 513 |
+
for (const auto& tag_ : getTags()) {
|
| 514 |
+
if (tag == tag_) {
|
| 515 |
+
return true;
|
| 516 |
+
}
|
| 517 |
+
}
|
| 518 |
+
return false;
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
template <class FuncType>
|
| 522 |
+
TypedOperatorHandle<FuncType> typed() const {
|
| 523 |
+
// NB: This assert is not 100% sound: you can retrieve a typed() operator
|
| 524 |
+
// handle prior to ANY C++ signature being registered on the operator
|
| 525 |
+
// and the check will say everything is OK (at which point you can then
|
| 526 |
+
// smuggle in a kernel that is typed incorrectly). For everything
|
| 527 |
+
// in core library this won't happen, because all the static registrations
|
| 528 |
+
// will be done by the time a typed() handle is acquired.
|
| 529 |
+
#if !defined C10_MOBILE
|
| 530 |
+
operatorDef_->op.assertSignatureIsCorrect<FuncType>();
|
| 531 |
+
if (fn_has_symint<FuncType>::value) {
|
| 532 |
+
operatorDef_->op.assertSignatureIsCorrect<
|
| 533 |
+
typename fn_remove_symint<FuncType>::type>();
|
| 534 |
+
}
|
| 535 |
+
#endif
|
| 536 |
+
return TypedOperatorHandle<FuncType>(operatorIterator_);
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
void callBoxed(Stack* stack) const {
|
| 540 |
+
c10::Dispatcher::singleton().callBoxed(*this, stack);
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
void callBoxed(Stack& stack) const {
|
| 544 |
+
callBoxed(&stack);
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const {
|
| 548 |
+
c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack);
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
|
| 552 |
+
c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
template <typename F>
|
| 556 |
+
PyObject* getPythonOp(
|
| 557 |
+
c10::impl::PyInterpreter* self_interpreter,
|
| 558 |
+
F slow_accessor) const {
|
| 559 |
+
return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
bool operator==(const OperatorHandle& other) const {
|
| 563 |
+
return operatorDef_ == other.operatorDef_;
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
bool operator!=(const OperatorHandle& other) const {
|
| 567 |
+
return operatorDef_ != other.operatorDef_;
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
private:
|
| 571 |
+
explicit OperatorHandle(
|
| 572 |
+
std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
|
| 573 |
+
: operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
|
| 574 |
+
friend class Dispatcher;
|
| 575 |
+
template <class>
|
| 576 |
+
friend class TypedOperatorHandle;
|
| 577 |
+
|
| 578 |
+
// Storing a direct pointer to the OperatorDef even though we
|
| 579 |
+
// already have the iterator saves an instruction in the critical
|
| 580 |
+
// dispatch path. The iterator is effectively a
|
| 581 |
+
// pointer-to-std::list-node, and (at least in libstdc++'s
|
| 582 |
+
// implementation) the element is at an offset 16 bytes from that,
|
| 583 |
+
// because the prev/next pointers come first in the list node
|
| 584 |
+
// struct. So, an add instruction would be necessary to convert from the
|
| 585 |
+
// iterator to an OperatorDef*.
|
| 586 |
+
Dispatcher::OperatorDef* operatorDef_;
|
| 587 |
+
|
| 588 |
+
// We need to store this iterator in order to make
|
| 589 |
+
// Dispatcher::cleanup() fast -- it runs a lot on program
|
| 590 |
+
// termination (and presumably library unloading).
|
| 591 |
+
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
|
| 592 |
+
};
|
| 593 |
+
|
| 594 |
+
/**
|
| 595 |
+
* This is a handle to an operator schema registered with the dispatcher.
|
| 596 |
+
* It holds the same information as an OperatorHandle, but it is templated
|
| 597 |
+
* on the operator arguments and allows calling the operator in an
|
| 598 |
+
* unboxed way.
|
| 599 |
+
*/
|
| 600 |
+
template <class FuncType>
|
| 601 |
+
class TypedOperatorHandle final {
|
| 602 |
+
static_assert(
|
| 603 |
+
guts::false_t<FuncType>(),
|
| 604 |
+
"FuncType in OperatorHandle::typed<FuncType> was not a valid function type");
|
| 605 |
+
};
|
| 606 |
+
template <class Return, class... Args>
|
| 607 |
+
class TypedOperatorHandle<Return(Args...)> final : public OperatorHandle {
|
| 608 |
+
public:
|
| 609 |
+
TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
|
| 610 |
+
TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
|
| 611 |
+
TypedOperatorHandle(const TypedOperatorHandle&) = default;
|
| 612 |
+
TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;
|
| 613 |
+
|
| 614 |
+
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use
|
| 615 |
+
// &&
|
| 616 |
+
C10_ALWAYS_INLINE Return call(Args... args) const {
|
| 617 |
+
return c10::Dispatcher::singleton().call<Return, Args...>(
|
| 618 |
+
*this, std::forward<Args>(args)...);
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use
|
| 622 |
+
// &&
|
| 623 |
+
C10_ALWAYS_INLINE Return
|
| 624 |
+
redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
|
| 625 |
+
return c10::Dispatcher::singleton().redispatch<Return, Args...>(
|
| 626 |
+
*this, currentDispatchKeySet, std::forward<Args>(args)...);
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
private:
|
| 630 |
+
explicit TypedOperatorHandle(
|
| 631 |
+
std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
|
| 632 |
+
: OperatorHandle(operatorIterator) {}
|
| 633 |
+
friend class OperatorHandle;
|
| 634 |
+
};
|
| 635 |
+
|
| 636 |
+
namespace detail {
|
| 637 |
+
template <class... Args>
|
| 638 |
+
inline void unused_arg_(const Args&... /*unused*/) {}
|
| 639 |
+
|
| 640 |
+
// CaptureKernelCall is intended to capture return values from Dispatcher
|
| 641 |
+
// unboxed kernel calls. A record function may request to get outputs from the
|
| 642 |
+
// kernel calls. For boxed kernels, it's straightforward, the returned values
|
| 643 |
+
// are in the stack object. The stack can be passed to record functions. For
|
| 644 |
+
// unboxed kernels, we need to handle different kinds of return values, cache
|
| 645 |
+
// them temporarily, then release the values for the actual function call
|
| 646 |
+
// return.
|
| 647 |
+
template <typename ReturnType>
|
| 648 |
+
struct CaptureKernelCall {
|
| 649 |
+
template <typename F, typename... Args>
|
| 650 |
+
CaptureKernelCall(
|
| 651 |
+
const F& kernel,
|
| 652 |
+
const TypedOperatorHandle<ReturnType(Args...)>& op,
|
| 653 |
+
const DispatchKeySet& dispatchKeySet,
|
| 654 |
+
Args&&... args)
|
| 655 |
+
// Calls the kernel and capture the result in output_.
|
| 656 |
+
: output_{kernel.template call<ReturnType, Args...>(
|
| 657 |
+
op,
|
| 658 |
+
dispatchKeySet,
|
| 659 |
+
std::forward<Args>(args)...)} {}
|
| 660 |
+
// Wraps the return values in a Stack.
|
| 661 |
+
Stack getOutputs() {
|
| 662 |
+
Stack stack;
|
| 663 |
+
impl::push_outputs<ReturnType, false>::copy(output_, &stack);
|
| 664 |
+
return stack;
|
| 665 |
+
}
|
| 666 |
+
// Since we are returning the output_, we don't expect the output_ to be used
|
| 667 |
+
// afterward. Copy elision and RVO do not apply to class data members. Using
|
| 668 |
+
// move semantic to avoid copies when possible.
|
| 669 |
+
ReturnType release() && {
|
| 670 |
+
return std::move(output_);
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
private:
|
| 674 |
+
ReturnType output_;
|
| 675 |
+
};
|
| 676 |
+
|
| 677 |
+
// Handle the lvalue reference differently since it should not be moved.
|
| 678 |
+
template <>
|
| 679 |
+
inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
|
| 680 |
+
return output_;
|
| 681 |
+
}
|
| 682 |
+
|
| 683 |
+
// Handle case where the kernel returns void.
|
| 684 |
+
template <>
|
| 685 |
+
struct CaptureKernelCall<void> {
|
| 686 |
+
template <typename F, typename... Args>
|
| 687 |
+
CaptureKernelCall(
|
| 688 |
+
const F& kernel,
|
| 689 |
+
const TypedOperatorHandle<void(Args...)>& op,
|
| 690 |
+
const DispatchKeySet& dispatchKeySet,
|
| 691 |
+
Args&&... args) {
|
| 692 |
+
// Calling the kernel and no need to capture void.
|
| 693 |
+
kernel.template call<void, Args...>(
|
| 694 |
+
op, dispatchKeySet, std::forward<Args>(args)...);
|
| 695 |
+
}
|
| 696 |
+
Stack getOutputs() {
|
| 697 |
+
return Stack();
|
| 698 |
+
}
|
| 699 |
+
void release() && {}
|
| 700 |
+
};
|
| 701 |
+
|
| 702 |
+
// Declaration only; the definition is not in this header.
// The dispatch entry points in this file call it, when show_dispatch_trace()
// is true, with a bracketed label such as "[call]", the operator name as a
// string, and the dispatch key set in use.
TORCH_API void _print_dispatch_trace(
    const std::string& label,
    const std::string& op_name,
    const DispatchKeySet& dispatchKeySet);
|
| 706 |
+
|
| 707 |
+
} // namespace detail
|
| 708 |
+
|
| 709 |
+
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
|
| 710 |
+
template <class Return, class... Args>
|
| 711 |
+
inline Return Dispatcher::callWithDispatchKeySlowPath(
|
| 712 |
+
const TypedOperatorHandle<Return(Args...)>& op,
|
| 713 |
+
at::StepCallbacks& stepCallbacks,
|
| 714 |
+
DispatchKeySet dispatchKeySet,
|
| 715 |
+
const KernelFunction& kernel,
|
| 716 |
+
Args... args) {
|
| 717 |
+
// If callbacks need inputs, we box the arguments and pass them to the guard.
|
| 718 |
+
// Note: For perf reasons we wouldn't want to prematurely box the arguments.
|
| 719 |
+
at::RecordFunction guard(std::move(stepCallbacks));
|
| 720 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
|
| 721 |
+
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
|
| 722 |
+
auto& schema = op.schema();
|
| 723 |
+
auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
|
| 724 |
+
constexpr auto num_boxed_args = impl::boxed_size<Args...>();
|
| 725 |
+
if constexpr (num_boxed_args != 0) {
|
| 726 |
+
if (guard.needsInputs()) {
|
| 727 |
+
// If we used std::array<IValue, num_boxed_args> here, we would
|
| 728 |
+
// have to spend time default constructing the IValues in
|
| 729 |
+
// boxedArgs. aligned_storage has no such requirement.
|
| 730 |
+
// NOLINTNEXTLINE(*array*)
|
| 731 |
+
alignas(IValue) std::byte boxedArgs[num_boxed_args * sizeof(IValue)];
|
| 732 |
+
// For debugging only; could be removed (but the compiler will do
|
| 733 |
+
// that for us and it's nice to have the extra assurance of
|
| 734 |
+
// correctness from our debug builds).
|
| 735 |
+
IValue* boxedArgsPtr = reinterpret_cast<IValue*>(boxedArgs);
|
| 736 |
+
impl::boxArgsToStack(boxedArgsPtr, args...);
|
| 737 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 738 |
+
reinterpret_cast<std::byte*>(boxedArgsPtr) ==
|
| 739 |
+
boxedArgs + num_boxed_args * sizeof(IValue));
|
| 740 |
+
// I don't *think* we need std::launder here, because IValue has
|
| 741 |
+
// no subclasses and no const or reference fields.
|
| 742 |
+
runRecordFunction(
|
| 743 |
+
guard,
|
| 744 |
+
schema_ref,
|
| 745 |
+
dispatchKey,
|
| 746 |
+
dispatchKeySet,
|
| 747 |
+
c10::ArrayRef<const c10::IValue>(
|
| 748 |
+
reinterpret_cast<IValue*>(boxedArgs), num_boxed_args));
|
| 749 |
+
boxedArgsPtr = reinterpret_cast<IValue*>(boxedArgs);
|
| 750 |
+
for (size_t ii = 0; ii < num_boxed_args; ++ii) {
|
| 751 |
+
(boxedArgsPtr + ii)->~IValue();
|
| 752 |
+
}
|
| 753 |
+
} else {
|
| 754 |
+
runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
|
| 755 |
+
}
|
| 756 |
+
} else {
|
| 757 |
+
runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
if (C10_UNLIKELY(guard.needsOutputs())) {
|
| 761 |
+
// Calls the kernel and capture the output temporarily to pass to
|
| 762 |
+
// RecordFunction.
|
| 763 |
+
detail::CaptureKernelCall<Return> captureKernelCall(
|
| 764 |
+
kernel, op, dispatchKeySet, std::forward<Args>(args)...);
|
| 765 |
+
guard.setOutputs(captureKernelCall.getOutputs());
|
| 766 |
+
// Releases the captured output to return to caller.
|
| 767 |
+
return std::move(captureKernelCall).release();
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
// keeping the guard alive while executing the kernel
|
| 771 |
+
return kernel.template call<Return, Args...>(
|
| 772 |
+
op, dispatchKeySet, std::forward<Args>(args)...);
|
| 773 |
+
}
|
| 774 |
+
|
| 775 |
+
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
|
| 776 |
+
template <class Return, class... Args>
|
| 777 |
+
C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(
|
| 778 |
+
const TypedOperatorHandle<Return(Args...)>& op,
|
| 779 |
+
Args... args) const {
|
| 780 |
+
auto dispatchKeySet =
|
| 781 |
+
op.operatorDef_->op.dispatchKeyExtractor()
|
| 782 |
+
.template getDispatchKeySetUnboxed<Args...>(args...);
|
| 783 |
+
#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
|
| 784 |
+
DispatchTraceNestingGuard debug_guard;
|
| 785 |
+
if (show_dispatch_trace()) {
|
| 786 |
+
detail::_print_dispatch_trace(
|
| 787 |
+
"[call]", toString(op.operator_name()), dispatchKeySet);
|
| 788 |
+
}
|
| 789 |
+
#endif
|
| 790 |
+
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
|
| 791 |
+
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
| 792 |
+
auto step_callbacks =
|
| 793 |
+
at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
|
| 794 |
+
if (C10_UNLIKELY(
|
| 795 |
+
step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
|
| 796 |
+
return callWithDispatchKeySlowPath<Return, Args...>(
|
| 797 |
+
op,
|
| 798 |
+
*step_callbacks,
|
| 799 |
+
dispatchKeySet,
|
| 800 |
+
kernel,
|
| 801 |
+
std::forward<Args>(args)...);
|
| 802 |
+
}
|
| 803 |
+
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
|
| 804 |
+
|
| 805 |
+
#ifdef FBCODE_CAFFE2
|
| 806 |
+
if (profilingOperatorEvents()) {
|
| 807 |
+
std::vector<void*> argsAddresses = {(void*)(&args)...};
|
| 808 |
+
std::vector<const char*> argsTypes = {(typeid(args).name())...};
|
| 809 |
+
struct FireOpRAII {
|
| 810 |
+
FireOpRAII(
|
| 811 |
+
at::RecordFunction::schema_ref_t schema_ref,
|
| 812 |
+
std::vector<void*>& argsAddresses,
|
| 813 |
+
std::vector<const char*>& argsTypes)
|
| 814 |
+
: schema_ref_(schema_ref) {
|
| 815 |
+
fireOpStartUSDT(schema_ref, argsAddresses, argsTypes);
|
| 816 |
+
}
|
| 817 |
+
~FireOpRAII() {
|
| 818 |
+
fireOpEndUSDT(schema_ref_);
|
| 819 |
+
}
|
| 820 |
+
at::RecordFunction::schema_ref_t schema_ref_;
|
| 821 |
+
} event(op.schema(), argsAddresses, argsTypes);
|
| 822 |
+
return kernel.template call<Return, Args...>(
|
| 823 |
+
op, dispatchKeySet, std::forward<Args>(args)...);
|
| 824 |
+
} else {
|
| 825 |
+
return kernel.template call<Return, Args...>(
|
| 826 |
+
op, dispatchKeySet, std::forward<Args>(args)...);
|
| 827 |
+
}
|
| 828 |
+
#else
|
| 829 |
+
return kernel.template call<Return, Args...>(
|
| 830 |
+
op, dispatchKeySet, std::forward<Args>(args)...);
|
| 831 |
+
#endif // FBCODE_CAFFE2
|
| 832 |
+
}
|
| 833 |
+
|
| 834 |
+
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
|
| 835 |
+
template <class Return, class... Args>
|
| 836 |
+
inline Return Dispatcher::redispatch(
|
| 837 |
+
const TypedOperatorHandle<Return(Args...)>& op,
|
| 838 |
+
DispatchKeySet currentDispatchKeySet,
|
| 839 |
+
Args... args) const {
|
| 840 |
+
// do not use RecordFunction on redispatch
|
| 841 |
+
#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
|
| 842 |
+
DispatchTraceNestingGuard debug_guard;
|
| 843 |
+
if (show_dispatch_trace()) {
|
| 844 |
+
detail::_print_dispatch_trace(
|
| 845 |
+
"[redispatch]", toString(op.operator_name()), currentDispatchKeySet);
|
| 846 |
+
}
|
| 847 |
+
#endif
|
| 848 |
+
const KernelFunction& kernel =
|
| 849 |
+
op.operatorDef_->op.lookup(currentDispatchKeySet);
|
| 850 |
+
return kernel.template call<Return, Args...>(
|
| 851 |
+
op, currentDispatchKeySet, std::forward<Args>(args)...);
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack)
|
| 855 |
+
const {
|
| 856 |
+
// note: this doesn't need the mutex because write operations on the list keep
|
| 857 |
+
// iterators intact.
|
| 858 |
+
const auto& entry = op.operatorDef_->op;
|
| 859 |
+
auto dispatchKeySet =
|
| 860 |
+
entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
|
| 861 |
+
#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
|
| 862 |
+
DispatchTraceNestingGuard debug_guard;
|
| 863 |
+
if (show_dispatch_trace()) {
|
| 864 |
+
detail::_print_dispatch_trace(
|
| 865 |
+
"[callBoxed]", toString(op.operator_name()), dispatchKeySet);
|
| 866 |
+
}
|
| 867 |
+
#endif
|
| 868 |
+
const auto& kernel = entry.lookup(dispatchKeySet);
|
| 869 |
+
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
| 870 |
+
auto step_callbacks =
|
| 871 |
+
at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
|
| 872 |
+
if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
|
| 873 |
+
at::RecordFunction guard(std::move(*step_callbacks));
|
| 874 |
+
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
|
| 875 |
+
auto& schema = op.schema();
|
| 876 |
+
auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
|
| 877 |
+
guard.needsInputs()
|
| 878 |
+
? runRecordFunction(
|
| 879 |
+
guard,
|
| 880 |
+
schema_ref,
|
| 881 |
+
dispatchKey,
|
| 882 |
+
dispatchKeySet,
|
| 883 |
+
c10::ArrayRef<const c10::IValue>(stack->data(), stack->size()))
|
| 884 |
+
: runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
|
| 885 |
+
|
| 886 |
+
// keeping the guard alive while executing the kernel
|
| 887 |
+
kernel.callBoxed(op, dispatchKeySet, stack);
|
| 888 |
+
|
| 889 |
+
if (C10_UNLIKELY(guard.needsOutputs())) {
|
| 890 |
+
guard.setOutputs(*stack);
|
| 891 |
+
}
|
| 892 |
+
return;
|
| 893 |
+
}
|
| 894 |
+
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
|
| 895 |
+
kernel.callBoxed(op, dispatchKeySet, stack);
|
| 896 |
+
}
|
| 897 |
+
|
| 898 |
+
// NB: this doesn't count as a "true" dispatcher jump, so no instrumentation
|
| 899 |
+
inline void Dispatcher::callBoxedForDispatchKey(
|
| 900 |
+
const OperatorHandle& op,
|
| 901 |
+
DispatchKey dk,
|
| 902 |
+
Stack* stack) const {
|
| 903 |
+
// note: this doesn't need the mutex because write operations on the list keep
|
| 904 |
+
// iterators intact.
|
| 905 |
+
const auto& entry = op.operatorDef_->op;
|
| 906 |
+
// We still compute this as we're obligated to pass it on to the internal
|
| 907 |
+
// kernel, if it is a boxed fallback
|
| 908 |
+
auto dispatchKeySet =
|
| 909 |
+
entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
|
| 910 |
+
const auto& kernel = ([&]() {
|
| 911 |
+
if (op.hasKernelForDispatchKey(dk)) {
|
| 912 |
+
return entry.kernelForDispatchKey(dk);
|
| 913 |
+
} else {
|
| 914 |
+
auto idx = getDispatchTableIndexForDispatchKey(dk);
|
| 915 |
+
TORCH_INTERNAL_ASSERT(idx >= 0);
|
| 916 |
+
return backendFallbackKernels_[idx].kernel;
|
| 917 |
+
}
|
| 918 |
+
})();
|
| 919 |
+
kernel.callBoxed(op, dispatchKeySet, stack);
|
| 920 |
+
}
|
| 921 |
+
|
| 922 |
+
inline void Dispatcher::redispatchBoxed(
|
| 923 |
+
const OperatorHandle& op,
|
| 924 |
+
DispatchKeySet dispatchKeySet,
|
| 925 |
+
Stack* stack) const {
|
| 926 |
+
// note: this doesn't need the mutex because write operations on the list keep
|
| 927 |
+
// iterators intact.
|
| 928 |
+
const auto& entry = op.operatorDef_->op;
|
| 929 |
+
#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
|
| 930 |
+
DispatchTraceNestingGuard debug_guard;
|
| 931 |
+
if (show_dispatch_trace()) {
|
| 932 |
+
detail::_print_dispatch_trace(
|
| 933 |
+
"[redispatchBoxed]", toString(op.operator_name()), dispatchKeySet);
|
| 934 |
+
}
|
| 935 |
+
#endif
|
| 936 |
+
const auto& kernel = entry.lookup(dispatchKeySet);
|
| 937 |
+
kernel.callBoxed(op, dispatchKeySet, stack);
|
| 938 |
+
}
|
| 939 |
+
|
| 940 |
+
} // namespace c10
|
| 941 |
+
|
| 942 |
+
namespace std {
|
| 943 |
+
|
| 944 |
+
template <>
|
| 945 |
+
struct hash<c10::OperatorHandle> {
|
| 946 |
+
size_t operator()(const c10::OperatorHandle& op) const noexcept {
|
| 947 |
+
return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
|
| 948 |
+
}
|
| 949 |
+
};
|
| 950 |
+
|
| 951 |
+
} // namespace std
|
| 952 |
+
|
| 953 |
+
#else
|
| 954 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 955 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/operator_name.h>
|
| 5 |
+
#include <string>
|
| 6 |
+
#include <unordered_set>
|
| 7 |
+
|
| 8 |
+
namespace c10 {
|
| 9 |
+
|
| 10 |
+
struct TORCH_API ObservedOperators {
|
| 11 |
+
ObservedOperators() = delete;
|
| 12 |
+
|
| 13 |
+
static bool isObserved(const OperatorName& name);
|
| 14 |
+
|
| 15 |
+
static std::unordered_set<std::string>& getUnobservedOperatorList();
|
| 16 |
+
};
|
| 17 |
+
|
| 18 |
+
} // namespace c10
|
| 19 |
+
|
| 20 |
+
#else
|
| 21 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 22 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/boxing/KernelFunction.h>
|
| 5 |
+
#include <ATen/core/dispatch/DispatchKeyExtractor.h>
|
| 6 |
+
#include <ATen/core/function_schema.h>
|
| 7 |
+
#include <ATen/core/ivalue.h>
|
| 8 |
+
#include <c10/core/DispatchKey.h>
|
| 9 |
+
#include <c10/core/PyHandleCache.h>
|
| 10 |
+
#include <c10/core/SafePyObject.h>
|
| 11 |
+
#include <c10/util/Metaprogramming.h>
|
| 12 |
+
#include <c10/util/flat_hash_map.h>
|
| 13 |
+
|
| 14 |
+
#include <ATen/core/dispatch/CppSignature.h>
|
| 15 |
+
#include <ATen/core/dispatch/OperatorOptions.h>
|
| 16 |
+
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
|
| 17 |
+
#include <ATen/core/enum_tag.h>
|
| 18 |
+
|
| 19 |
+
#include <array>
|
| 20 |
+
#include <list>
|
| 21 |
+
#include <optional>
|
| 22 |
+
|
| 23 |
+
#ifdef C10_MOBILE
|
| 24 |
+
#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
|
| 25 |
+
#endif
|
| 26 |
+
|
| 27 |
+
namespace c10 {
|
| 28 |
+
|
| 29 |
+
class Dispatcher;
|
| 30 |
+
|
| 31 |
+
namespace impl {
|
| 32 |
+
|
| 33 |
+
// This data structure represents a kernel that was registered to us from a
|
| 34 |
+
// user. Unlike KernelFunction, AnnotatedKernel contains some extra metadata
|
| 35 |
+
// about the kernel that isn't necessary for actual dispatching (this is why
|
| 36 |
+
// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for
|
| 37 |
+
// giving good error messages.
|
| 38 |
+
struct AnnotatedKernel final {
|
| 39 |
+
AnnotatedKernel(
|
| 40 |
+
KernelFunction k,
|
| 41 |
+
std::unique_ptr<FunctionSchema> s,
|
| 42 |
+
std::string d)
|
| 43 |
+
: kernel(std::move(k)),
|
| 44 |
+
inferred_function_schema(std::move(s)),
|
| 45 |
+
debug(std::move(d)) {}
|
| 46 |
+
AnnotatedKernel() = default;
|
| 47 |
+
KernelFunction kernel;
|
| 48 |
+
std::unique_ptr<FunctionSchema> inferred_function_schema;
|
| 49 |
+
// A little debug string to help us identify the kernel in question.
|
| 50 |
+
// Most importantly it records the TORCH_LIBRARY block that did the
|
| 51 |
+
// registration.
|
| 52 |
+
std::string debug;
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
// This data structure represents operator schema, with metadata specifying
|
| 56 |
+
// where the registration of this schema occurred
|
| 57 |
+
struct AnnotatedSchema final {
|
| 58 |
+
AnnotatedSchema(FunctionSchema s, std::string d)
|
| 59 |
+
: schema(std::move(s)), debug(std::move(d)) {}
|
| 60 |
+
FunctionSchema schema;
|
| 61 |
+
std::string debug;
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
// Internal data structure that records information about a specific operator.
|
| 65 |
+
// It's not part of the public API; typically, users will interact with
|
| 66 |
+
// OperatorHandle instead.
|
| 67 |
+
//
|
| 68 |
+
// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher
|
| 69 |
+
// lock (this is important because some methods in OperatorEntry access
|
| 70 |
+
// dispatcher state)
|
| 71 |
+
class TORCH_API OperatorEntry final {
|
| 72 |
+
public:
|
| 73 |
+
explicit OperatorEntry(OperatorName&& operator_name);
|
| 74 |
+
|
| 75 |
+
OperatorEntry(const OperatorEntry&) = delete;
|
| 76 |
+
OperatorEntry(OperatorEntry&&) noexcept = delete;
|
| 77 |
+
OperatorEntry& operator=(const OperatorEntry&) = delete;
|
| 78 |
+
OperatorEntry& operator=(OperatorEntry&&) noexcept = delete;
|
| 79 |
+
|
| 80 |
+
const FunctionSchema& schema() const {
|
| 81 |
+
TORCH_INTERNAL_ASSERT(
|
| 82 |
+
schema_.has_value(),
|
| 83 |
+
"Tried to access the schema for ",
|
| 84 |
+
name_,
|
| 85 |
+
" which doesn't have a schema registered yet");
|
| 86 |
+
return schema_->schema;
|
| 87 |
+
}
|
| 88 |
+
const std::string& debug() const {
|
| 89 |
+
TORCH_INTERNAL_ASSERT(schema_.has_value());
|
| 90 |
+
return schema_->debug;
|
| 91 |
+
}
|
| 92 |
+
bool hasSchema() const {
|
| 93 |
+
return schema_.has_value();
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
bool isObserved() const {
|
| 97 |
+
return is_observed_;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// We may allocate an OperatorEntry for an operator even when we don't
|
| 101 |
+
// have a schema. When we receive the schema registration, we post
|
| 102 |
+
// facto register a schema.
|
| 103 |
+
//
|
| 104 |
+
// NB: registerSchema/deregisterSchema are not idempotent; if you
|
| 105 |
+
// attempt to register a schema when one is already present or vice
|
| 106 |
+
// versa that is an error. (Refcounting for the registrations is
|
| 107 |
+
// handled in the OperatorHandle in Dispatcher)
|
| 108 |
+
void registerSchema(
|
| 109 |
+
FunctionSchema&& /*schema*/,
|
| 110 |
+
std::string&& debug,
|
| 111 |
+
std::vector<at::Tag> tags = {});
|
| 112 |
+
void deregisterSchema();
|
| 113 |
+
|
| 114 |
+
const OperatorName& operator_name() const {
|
| 115 |
+
return name_;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
|
| 119 |
+
using AnnotatedKernelContainer = std::array<AnnotatedKernel, 1>;
|
| 120 |
+
#else
|
| 121 |
+
using AnnotatedKernelContainer = std::list<AnnotatedKernel>;
|
| 122 |
+
#endif
|
| 123 |
+
using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator;
|
| 124 |
+
|
| 125 |
+
// Why are kernels and fallback asymmetric? It has to do with ownership.
|
| 126 |
+
// Kernels and the computed dispatch tables for them are canonically
|
| 127 |
+
// owned by OperatorEntry, but backend fallbacks are specified once
|
| 128 |
+
// and apply for all operators, so they should be owned by Dispatcher.
|
| 129 |
+
// However, the registration of a backend fallback affects the
|
| 130 |
+
// state of the computed dispatch table, so when a backend fallback
|
| 131 |
+
// is updated, we need to update the operator tables too. Thus,
|
| 132 |
+
// registerKernel is the mechanism by which we give kernels to
|
| 133 |
+
// operator entry to own (and update dispatch table), but we only
|
| 134 |
+
// need a non-owning mechanism to update fallback.
|
| 135 |
+
|
| 136 |
+
// Precondition: Dispatcher::mutex_ is held
|
| 137 |
+
// Postcondition: caller is responsible for disposing of the kernel
|
| 138 |
+
AnnotatedKernelContainerIterator registerKernel(
|
| 139 |
+
const Dispatcher& dispatcher,
|
| 140 |
+
std::optional<DispatchKey> dispatch_key,
|
| 141 |
+
KernelFunction kernel,
|
| 142 |
+
std::optional<CppSignature> cpp_signature,
|
| 143 |
+
std::unique_ptr<FunctionSchema> inferred_function_schema,
|
| 144 |
+
std::string debug);
|
| 145 |
+
|
| 146 |
+
// Precondition: Dispatcher::mutex_ is held
|
| 147 |
+
void deregisterKernel_(
|
| 148 |
+
const Dispatcher& dispatcher,
|
| 149 |
+
std::optional<DispatchKey> dispatch_key,
|
| 150 |
+
AnnotatedKernelContainerIterator kernel);
|
| 151 |
+
|
| 152 |
+
// Precondition: Dispatcher::mutex_ is held
|
| 153 |
+
void updateFallback(const Dispatcher& dispatcher, DispatchKey dispatch_key);
|
| 154 |
+
|
| 155 |
+
// Precondition: Dispatcher::mutex_ is held
|
| 156 |
+
void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
|
| 157 |
+
TORCH_INTERNAL_ASSERT(schema_.has_value());
|
| 158 |
+
schema_->schema.setAliasAnalysis(a);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
std::string dumpComputedTable() const;
|
| 162 |
+
std::string dumpState() const;
|
| 163 |
+
void checkInvariants() const;
|
| 164 |
+
|
| 165 |
+
const DispatchKeyExtractor& dispatchKeyExtractor() const {
|
| 166 |
+
return dispatchKeyExtractor_;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// Asserts that the given FuncType is correct for calling this operator in an
|
| 170 |
+
// unboxed way.
|
| 171 |
+
template <class FuncType>
|
| 172 |
+
inline void assertSignatureIsCorrect() {
|
| 173 |
+
assertSignatureIsCorrect(
|
| 174 |
+
CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
void assertSignatureIsCorrect(
|
| 178 |
+
const CppSignature& call_signature,
|
| 179 |
+
bool has_symint) const;
|
| 180 |
+
|
| 181 |
+
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
|
| 182 |
+
|
| 183 |
+
const KernelFunction& lookup(DispatchKeySet ks) const {
|
| 184 |
+
const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
|
| 185 |
+
if (C10_UNLIKELY(idx == -1)) {
|
| 186 |
+
reportError(ks.highestPriorityTypeId());
|
| 187 |
+
}
|
| 188 |
+
const auto& kernel = dispatchTable_[idx];
|
| 189 |
+
// A valid kernel *always* has a boxed kernel and *may* have an
|
| 190 |
+
// unboxed kernel. However, we typically do unboxed calls in at::
|
| 191 |
+
// APIs, where the kernel 1) will very likely be valid and 2)
|
| 192 |
+
// should have an unboxed kernel. Checking the unboxed kernel
|
| 193 |
+
// first will allow us to avoid touching the boxed kernel at all
|
| 194 |
+
// in the common case.
|
| 195 |
+
if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
|
| 196 |
+
if (!kernel.isValid()) {
|
| 197 |
+
reportError(ks.highestPriorityTypeId());
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
return kernel;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
std::string listAllDispatchKeys() const;
|
| 204 |
+
|
| 205 |
+
// Returns true if kernel_ has entry for any key in ks.
|
| 206 |
+
//
|
| 207 |
+
// Invariant: There are no alias keys in the passed-in dispatch key set.
|
| 208 |
+
// Note [No Alias Keys in DispatchKeySet]
|
| 209 |
+
// Alias keys should be checked using `hasKernelForDispatchKey`
|
| 210 |
+
// Alias keys shouldn't go inside of a DispatchKeySet, since they can
|
| 211 |
+
// technically have a value > 63 (causing overflow).
|
| 212 |
+
bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
|
| 213 |
+
// Returns true if kernel_ has entry for a particular key.
|
| 214 |
+
bool hasKernelForDispatchKey(DispatchKey k) const;
|
| 215 |
+
// Retrieves the kernel entry at a particular key. Symmetric with
|
| 216 |
+
// hasKernelForDispatchKey. To get the AnnotatedKernel, see
|
| 217 |
+
// getKernelForDispatchKey (private)
|
| 218 |
+
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
|
| 219 |
+
// Returns true if the "computed table" has an entry for a particular key.
|
| 220 |
+
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
|
| 221 |
+
// Returns a KernelFunction corresponding to the kernel in dispatchTable
|
| 222 |
+
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
|
| 223 |
+
// Returns all the operator tags added at the time of registration
|
| 224 |
+
const std::vector<at::Tag>& getTags() const;
|
| 225 |
+
void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
|
| 226 |
+
|
| 227 |
+
template <typename F>
|
| 228 |
+
PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor)
|
| 229 |
+
const {
|
| 230 |
+
return py_cache_.ptr_or(self_interpreter, slow_accessor);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
private:
|
| 234 |
+
OperatorName name_;
|
| 235 |
+
std::optional<AnnotatedSchema> schema_;
|
| 236 |
+
#ifndef C10_MOBILE
|
| 237 |
+
std::vector<at::Tag> tags_;
|
| 238 |
+
#endif
|
| 239 |
+
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
|
| 240 |
+
DispatchKeyExtractor dispatchKeyExtractor_;
|
| 241 |
+
// Pointer to the torch.ops.ns.op.overload object for speed
|
| 242 |
+
c10::PyHandleCache py_cache_;
|
| 243 |
+
|
| 244 |
+
// kernels_ stores all registered kernels for the corresponding dispatch key
|
| 245 |
+
// and catchAllKernels_ stores the catch-all kernels.
|
| 246 |
+
// If an operator library gets loaded that overwrites an already existing
|
| 247 |
+
// kernel, both kernels will be in that list but only the newer one will be in
|
| 248 |
+
// dispatchTable. If any of the kernels go away (say the library gets
|
| 249 |
+
// unloaded), we remove the kernel from this list and update the
|
| 250 |
+
// dispatchTable if necessary.
|
| 251 |
+
// Kernels in the list are ordered by registration time descendingly,
|
| 252 |
+
// newer registrations are before older registrations.
|
| 253 |
+
// We do not combine dispatchTable and kernels into one hash map because
|
| 254 |
+
// kernels is a larger data structure and accessed quite infrequently
|
| 255 |
+
// while dispatchTable is accessed often and should be kept small to fit
|
| 256 |
+
// into CPU caches.
|
| 257 |
+
// Invariants:
|
| 258 |
+
// - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
|
| 259 |
+
// - dispatchTable[dispatch_key] does not exist if and only if
|
| 260 |
+
// kernels_[dispatch_key] does not exist
|
| 261 |
+
// - If kernels_[dispatch_key] exists, then it has elements.
|
| 262 |
+
// It is never an empty list.
|
| 263 |
+
//
|
| 264 |
+
// Why do we do that?
|
| 265 |
+
// -----
|
| 266 |
+
// We mostly do this to enable Jupyter notebooks where a cell registering
|
| 267 |
+
// a kernel could be executed multiple times and the later execution
|
| 268 |
+
// should overwrite the earlier one. Note that this still fails when the
|
| 269 |
+
// function schema changed between the executions, but it works as long
|
| 270 |
+
// as the function schema didn't change. A better solution would be to
|
| 271 |
+
// unload the old extension library from the Jupyter cell when the cell is
|
| 272 |
+
// re-executed and then only allow one kernel here, i.e. error if a kernel
|
| 273 |
+
// is already registered, but that's a lot of effort to implement and
|
| 274 |
+
// currently not high-pri.
|
| 275 |
+
ska::flat_hash_map<
|
| 276 |
+
DispatchKey,
|
| 277 |
+
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
|
| 278 |
+
// On mobile, we needn't worry about Jupyter notebooks.
|
| 279 |
+
std::array<AnnotatedKernel, 1>
|
| 280 |
+
#else
|
| 281 |
+
std::list<AnnotatedKernel>
|
| 282 |
+
#endif
|
| 283 |
+
>
|
| 284 |
+
kernels_;
|
| 285 |
+
|
| 286 |
+
const AnnotatedKernel& missingKernel() const;
|
| 287 |
+
const AnnotatedKernel& ambiguousAutogradOtherKernel() const;
|
| 288 |
+
|
| 289 |
+
// cpp_signature_ stores function signature if any of
|
| 290 |
+
// the kernels was created in a way that allowed us to know the function
|
| 291 |
+
// signature (i.e. by supplying an unboxed C++ kernel function).
|
| 292 |
+
// If this is set, it will be used to check that future kernel
|
| 293 |
+
// registrations match and it will be used in unboxed function calls
|
| 294 |
+
// to verify their arguments against the known function signature.
|
| 295 |
+
struct CppSignatureWithDebug {
|
| 296 |
+
CppSignature signature;
|
| 297 |
+
std::string debug;
|
| 298 |
+
std::optional<DispatchKey> dispatch_key;
|
| 299 |
+
};
|
| 300 |
+
std::optional<CppSignatureWithDebug> cpp_signature_;
|
| 301 |
+
std::optional<CppSignatureWithDebug> sym_cpp_signature_;
|
| 302 |
+
|
| 303 |
+
// A Python custom error handler for OperatorEntry::reportError
|
| 304 |
+
std::unique_ptr<c10::SafePyObject> report_error_callback_;
|
| 305 |
+
|
| 306 |
+
// Whether this operator needs to be observed with RecordFunction
|
| 307 |
+
const bool is_observed_;
|
| 308 |
+
|
| 309 |
+
[[noreturn]] void reportSignatureError(
|
| 310 |
+
const CppSignature& call_signature,
|
| 311 |
+
const CppSignatureWithDebug& saved_signature) const;
|
| 312 |
+
const KernelFunction& computeDispatchTableEntry(
|
| 313 |
+
const c10::Dispatcher& dispatcher,
|
| 314 |
+
DispatchKey dispatch_key) const;
|
| 315 |
+
std::pair<const AnnotatedKernel&, const char*>
|
| 316 |
+
computeDispatchTableEntryWithDebug(
|
| 317 |
+
const c10::Dispatcher& dispatcher,
|
| 318 |
+
DispatchKey dispatch_key) const;
|
| 319 |
+
// This function re-establishes the invariant that dispatchTable
|
| 320 |
+
// contains the front element from the kernels list for a given runtime
|
| 321 |
+
// dispatch key.
|
| 322 |
+
void updateDispatchTableEntry_(
|
| 323 |
+
const c10::Dispatcher& dispatcher,
|
| 324 |
+
DispatchKey dispatch_key);
|
| 325 |
+
// Like above, but also handles alias dispatch keys.
|
| 326 |
+
void updateDispatchTable_(
|
| 327 |
+
const c10::Dispatcher& dispatcher,
|
| 328 |
+
DispatchKey dispatch_key);
|
| 329 |
+
// Like above, but for ALL entries in the dispatch table.
|
| 330 |
+
void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
|
| 331 |
+
// Retrieves a pointer to AnnotatedKernel at
|
| 332 |
+
// kernels_.at(dispatch_key).front().
|
| 333 |
+
const AnnotatedKernel* getKernelForDispatchKey(
|
| 334 |
+
DispatchKey dispatch_key) const;
|
| 335 |
+
};
|
| 336 |
+
|
| 337 |
+
} // namespace impl
|
| 338 |
+
} // namespace c10
|
| 339 |
+
|
| 340 |
+
#else
|
| 341 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 342 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
namespace c10 {
|
| 7 |
+
|
| 8 |
+
// Kinds of alias analysis. The numeric values are written out explicitly;
// they are the same values the declaration order would assign implicitly.
enum class AliasAnalysisKind : uint8_t {
  INTERNAL_SPECIAL_CASE = 0,
  // The most conservative alias analysis type: assumes side-effects.
  // This is the default analysis.
  CONSERVATIVE = 1,
  FROM_SCHEMA = 2,
  PURE_FUNCTION = 3
};
|
| 15 |
+
|
| 16 |
+
#if !defined(_MSC_VER)
|
| 17 |
+
constexpr // Our current MSVC version has a bug that doesn't allow this to be
|
| 18 |
+
// constexpr.
|
| 19 |
+
#endif
|
| 20 |
+
inline const char*
|
| 21 |
+
toString(AliasAnalysisKind aliasAnalysisKind) {
|
| 22 |
+
return (aliasAnalysisKind == AliasAnalysisKind::CONSERVATIVE) ? "CONSERVATIVE"
|
| 23 |
+
: (aliasAnalysisKind == AliasAnalysisKind::FROM_SCHEMA) ? "FROM_SCHEMA"
|
| 24 |
+
: (aliasAnalysisKind == AliasAnalysisKind::PURE_FUNCTION)
|
| 25 |
+
? "PURE_FUNCTION"
|
| 26 |
+
: (aliasAnalysisKind == AliasAnalysisKind::INTERNAL_SPECIAL_CASE)
|
| 27 |
+
? "INTERNAL_SPECIAL_CASE"
|
| 28 |
+
: "UNKNOWN";
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
} // namespace c10
|
| 32 |
+
|
| 33 |
+
#else
|
| 34 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 35 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <functional>
|
| 5 |
+
|
| 6 |
+
namespace c10 {
|
| 7 |
+
|
| 8 |
+
/**
 * Move-only handle that owns a callback and runs it when the handle stops
 * owning it: on destruction, or when the handle is overwritten by move
 * assignment.
 *
 * - A moved-from handle owns nothing and runs nothing.
 * - Move assignment first runs the callback currently held by the target
 *   (previously it was silently dropped, so it never ran), then takes over
 *   the callback of `rhs`.
 * - Self-move-assignment is a no-op (previously it cleared the callback, so
 *   it never ran).
 */
class RegistrationHandleRAII final {
 public:
  /// Takes ownership of `onDestruction`; an empty function means "no-op".
  explicit RegistrationHandleRAII(std::function<void()> onDestruction)
      : onDestruction_(std::move(onDestruction)) {}

  ~RegistrationHandleRAII() {
    release_();
  }

  RegistrationHandleRAII(const RegistrationHandleRAII&) = delete;
  RegistrationHandleRAII& operator=(const RegistrationHandleRAII&) = delete;

  RegistrationHandleRAII(RegistrationHandleRAII&& rhs) noexcept
      : onDestruction_(std::move(rhs.onDestruction_)) {
    // A moved-from std::function is in an unspecified state; reset it
    // explicitly so that rhs's destructor does not run the callback again.
    rhs.onDestruction_ = nullptr;
  }

  RegistrationHandleRAII& operator=(RegistrationHandleRAII&& rhs) noexcept {
    if (this != &rhs) {
      // Release what we currently own before adopting rhs's callback.
      release_();
      onDestruction_ = std::move(rhs.onDestruction_);
      rhs.onDestruction_ = nullptr;
    }
    return *this;
  }

 private:
  // Runs the held callback (if any) exactly once and leaves the handle empty.
  // The member is cleared before the call so the callback cannot be re-run
  // through this handle.
  void release_() {
    if (onDestruction_) {
      std::function<void()> callback = std::move(onDestruction_);
      onDestruction_ = nullptr;
      callback();
    }
  }

  std::function<void()> onDestruction_;
};
|
| 36 |
+
|
| 37 |
+
} // namespace c10
|
| 38 |
+
|
| 39 |
+
#else
|
| 40 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 41 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/adaption.h
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/Tensor.h>
|
| 5 |
+
#include <ATen/TensorUtils.h>
|
| 6 |
+
#include <ATen/core/List.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
|
| 9 |
+
/*
|
| 10 |
+
* [Note: hacky wrapper removal for optional tensor]
|
| 11 |
+
*
|
| 12 |
+
* The kernel implementation takes an optional tensor marked in the schema as
|
| 13 |
+
* Tensor? but the C++ function takes Tensor instead of the std::optional<Tensor>
|
| 14 |
+
* expected by the dispatcher.
|
| 15 |
+
*
|
| 16 |
+
* To remove the hacky wrapper, the C++ function is changed to take
|
| 17 |
+
* std::optional<Tensor> and unwrap the Tensor value at the beginning of
|
| 18 |
+
* the function, e.g.:
|
| 19 |
+
* > c10::MaybeOwned<Tensor> weight_maybe_owned =
|
| 20 |
+
* > at::borrow_from_optional_tensor(weight_opt);
|
| 21 |
+
* > const Tensor& weight = *weight_maybe_owned;
|
| 22 |
+
*
|
| 23 |
+
* We may want to make the kernel handle optional directly without
|
| 24 |
+
* going through the creation of a default-constructed Tensor in
|
| 25 |
+
* at::borrow_from_optional_tensor.
|
| 26 |
+
*/
|
| 27 |
+
|
| 28 |
+
/*
|
| 29 |
+
* [Note: hacky wrapper removal for TensorOptions]
|
| 30 |
+
*
|
| 31 |
+
* The kernel implementation takes a TensorOptions argument but the dispatcher
|
| 32 |
+
* expects separate arguments for dtype, layout, device, pin_memory.
|
| 33 |
+
*
|
| 34 |
+
* To remove the hacky wrapper, the kernel implementation is changed to take
|
| 35 |
+
* the 4 arguments (dtype, layout, device, pin_memory), and assemble the
|
| 36 |
+
* TensorOptions value at the beginning of the function, e.g.:
|
| 37 |
+
* > TensorOptions options = TensorOptions().dtype(dtype).layout(layout)
|
| 38 |
+
* > .device(device).pinned_memory(pin_memory);
|
| 39 |
+
*
|
| 40 |
+
* We may want to make the kernel handle these parameters directly without going
|
| 41 |
+
* through the creation of a TensorOptions value.
|
| 42 |
+
*/
|
| 43 |
+
|
| 44 |
+
namespace c10::impl {
|
| 45 |
+
|
| 46 |
+
TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);
|
| 47 |
+
|
| 48 |
+
// Folds the device of `tensor` into `common_device`:
//  - undefined tensors are ignored;
//  - if no common device has been recorded yet, the tensor's device is stored;
//  - otherwise a mismatch is reported via common_device_check_failure().
inline void check_and_update_common_device(std::optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
  // TODO: Remove this once the following issue is addressed:
  // https://github.com/pytorch/pytorch/issues/57380
  if (tensor.defined()) {
    if (common_device.has_value()) {
      if (C10_UNLIKELY(common_device != tensor.device())) {
        common_device_check_failure(*common_device, tensor, methodName, argName);
      }
    } else {
      // First defined tensor seen: it establishes the common device.
      common_device = tensor.device();
    }
  }
}
|
| 64 |
+
|
| 65 |
+
// Optional-tensor overload: an empty optional is ignored, otherwise the
// contained tensor is forwarded to the at::Tensor overload.
inline void check_and_update_common_device(std::optional<Device>& common_device, const std::optional<at::Tensor>& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
  if (!tensor.has_value()) {
    return;
  }
  check_and_update_common_device(common_device, *tensor, methodName, argName);
}
|
| 70 |
+
|
| 71 |
+
inline void check_and_update_common_device(std::optional<Device>& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
|
| 72 |
+
for (const auto& tensor : tensors) {
|
| 73 |
+
check_and_update_common_device(common_device, tensor, methodName, argName);
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
inline void check_and_update_common_device(std::optional<Device>& common_device, const List<std::optional<at::Tensor>>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
|
| 78 |
+
for (const auto& tensor : tensors) {
|
| 79 |
+
check_and_update_common_device(common_device, tensor, methodName, argName);
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
} // namespace c10::impl
|
| 83 |
+
|
| 84 |
+
#else
|
| 85 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 86 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/infer_schema.h
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
/**
|
| 5 |
+
* This file contains functionality to take a C++ function and infer its
|
| 6 |
+
* c10::FunctionSchema.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#include <ATen/core/function_schema.h>
|
| 10 |
+
#include <c10/util/Metaprogramming.h>
|
| 11 |
+
|
| 12 |
+
namespace c10 {
|
| 13 |
+
namespace detail::infer_schema {
|
| 14 |
+
|
| 15 |
+
/// The templated inference code creates `ArgumentDef` instead of `Argument`,
|
| 16 |
+
/// because that can be constructed at compile time and has a much smaller
|
| 17 |
+
/// binary size than having calls to `Argument` constructors in the template.
|
| 18 |
+
/// Creating `Argument` objects from `ArgumentDef` can then be done at
|
| 19 |
+
/// runtime in a non-templated way.
|
| 20 |
+
struct ArgumentDef final {
|
| 21 |
+
using GetTypeFn = TypePtr();
|
| 22 |
+
GetTypeFn* getTypeFn;
|
| 23 |
+
GetTypeFn* getFakeTypeFn;
|
| 24 |
+
constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {}
|
| 25 |
+
explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {}
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
// Lifts a compile-time bool into a type: bool_t<true> derives from
// std::true_type and bool_t<false> from std::false_type.
template<bool V>
struct bool_t : std::bool_constant<V> {};
|
| 32 |
+
|
| 33 |
+
/// Checks the static C++ types `Types` for correctness to catch common error cases.
|
| 34 |
+
template <class... Types>
|
| 35 |
+
constexpr int checkStaticTypes() {
|
| 36 |
+
// Give nice error messages for some of the common error cases.
|
| 37 |
+
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
|
| 38 |
+
static_assert(std::conjunction_v<
|
| 39 |
+
bool_t<!std::is_integral_v<Types> || std::is_same_v<Types, int8_t> || std::is_same_v<Types, int64_t> || std::is_same_v<Types, bool>>...
|
| 40 |
+
>, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
|
| 41 |
+
static_assert(std::conjunction_v<
|
| 42 |
+
bool_t<!std::is_same_v<Types, float>>...
|
| 43 |
+
>, "INVALID TYPE: float is not supported as an argument type, use double instead");
|
| 44 |
+
return 0;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
template <typename... Ts, size_t... Is>
|
| 48 |
+
constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...> /*unused*/) {
|
| 49 |
+
return (
|
| 50 |
+
// Check types for common errors
|
| 51 |
+
checkStaticTypes<Ts...>(),
|
| 52 |
+
|
| 53 |
+
// Create the return value
|
| 54 |
+
std::array<ArgumentDef, sizeof...(Ts)>{
|
| 55 |
+
ArgumentDef(&getTypePtrCopy<std::decay_t<Ts>>, &getFakeTypePtrCopy<std::decay_t<Ts>>)...}
|
| 56 |
+
);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
|
| 60 |
+
/// as template arguments.
|
| 61 |
+
template<class ParameterTypes> struct createArguments final {};
|
| 62 |
+
template<class... ParameterTypes>
|
| 63 |
+
struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
|
| 64 |
+
static constexpr std::array<ArgumentDef, sizeof...(ParameterTypes)> call() {
|
| 65 |
+
return createArgumentVectorFromTypes<ParameterTypes...>(
|
| 66 |
+
std::make_index_sequence<sizeof...(ParameterTypes)>()
|
| 67 |
+
);
|
| 68 |
+
}
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
|
| 72 |
+
/// as a tuple (i.e. in the way c10 kernels return values).
|
| 73 |
+
/// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
|
| 74 |
+
/// It can be an empty tuple<>, or void for kernels that don't return anything.
|
| 75 |
+
/// It can be a single type A (i.e. no tuple) for the case where a kernel just
|
| 76 |
+
/// returns one value.
|
| 77 |
+
template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
|
| 78 |
+
|
| 79 |
+
template<class... ReturnTypes>
|
| 80 |
+
struct createReturns<std::tuple<ReturnTypes...>, void> final {
|
| 81 |
+
static constexpr std::array<ArgumentDef, sizeof...(ReturnTypes)> call() {
|
| 82 |
+
return createArgumentVectorFromTypes<ReturnTypes...>(
|
| 83 |
+
std::make_index_sequence<sizeof...(ReturnTypes)>()
|
| 84 |
+
);
|
| 85 |
+
}
|
| 86 |
+
};
|
| 87 |
+
|
| 88 |
+
template<class ReturnType>
|
| 89 |
+
struct createReturns<ReturnType, std::enable_if_t<!std::is_same_v<void, ReturnType> && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
|
| 90 |
+
static constexpr std::array<ArgumentDef, 1> call() {
|
| 91 |
+
return createReturns<std::tuple<ReturnType>>::call();
|
| 92 |
+
}
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
template<>
|
| 96 |
+
struct createReturns<void, void> final {
|
| 97 |
+
static constexpr std::array<ArgumentDef, 0> call() {
|
| 98 |
+
return createReturns<std::tuple<>>::call();
|
| 99 |
+
}
|
| 100 |
+
};
|
| 101 |
+
|
| 102 |
+
template <typename ReturnType>
|
| 103 |
+
struct createSingleReturn {
|
| 104 |
+
static constexpr std::array<ArgumentDef, 1> call() {
|
| 105 |
+
return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>());
|
| 106 |
+
}
|
| 107 |
+
};
|
| 108 |
+
|
| 109 |
+
TORCH_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
|
| 110 |
+
TORCH_API FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
|
| 111 |
+
|
| 112 |
+
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
|
| 113 |
+
/// function. Flattens std::tuple returns into multiple return types
|
| 114 |
+
template <typename FunctionTraits>
|
| 115 |
+
FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() {
|
| 116 |
+
using ReturnType = typename FunctionTraits::return_type;
|
| 117 |
+
using ParameterTypes = typename FunctionTraits::parameter_types;
|
| 118 |
+
|
| 119 |
+
// arguments and returns are computed into a std::array at compile time and embedded into the binary.
|
| 120 |
+
// The only code executed at runtime here is the one that creates a std::vector
|
| 121 |
+
// of the arguments/returns from the std::array.
|
| 122 |
+
constexpr auto arguments = createArguments<ParameterTypes>::call();
|
| 123 |
+
constexpr auto returns = createReturns<ReturnType>::call();
|
| 124 |
+
|
| 125 |
+
return make_function_schema(arguments, returns);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
|
| 129 |
+
/// function. Preserves std::tuple returns as a Tuple return type
|
| 130 |
+
template <typename FunctionTraits>
|
| 131 |
+
FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) {
|
| 132 |
+
using ReturnType = typename FunctionTraits::return_type;
|
| 133 |
+
using ParameterTypes = typename FunctionTraits::parameter_types;
|
| 134 |
+
|
| 135 |
+
// arguments and returns are computed into a std::array at compile time and embedded into the binary.
|
| 136 |
+
// The only code executed at runtime here is the one that creates a std::vector
|
| 137 |
+
// of the arguments/returns from the std::array.
|
| 138 |
+
constexpr auto arguments = createArguments<ParameterTypes>::call();
|
| 139 |
+
constexpr auto returns = createSingleReturn<ReturnType>::call();
|
| 140 |
+
|
| 141 |
+
return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
template<class FuncType>
|
| 147 |
+
FunctionSchema inferFunctionSchemaFlattenedReturns() {
|
| 148 |
+
return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>();
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
template<class FuncType>
|
| 152 |
+
FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) {
|
| 153 |
+
return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
TORCH_API std::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);
|
| 157 |
+
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
#else
|
| 161 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 162 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// TODO: unify to C10_MOBILE. In theory this header could be used in OSS.
|
| 5 |
+
#ifdef TEMPLATE_SELECTIVE_BUILD
|
| 6 |
+
#include <ATen/selected_mobile_ops.h>
|
| 7 |
+
#endif
|
| 8 |
+
|
| 9 |
+
/**
|
| 10 |
+
* This header implements functionality to build PyTorch with only a certain
|
| 11 |
+
* set of operators (+ dependencies) included.
|
| 12 |
+
*
|
| 13 |
+
* - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these
|
| 14 |
+
* two ops will be included in your build. The allowlist records operators
|
| 15 |
+
* only, no overloads; if you include aten::add, all overloads of aten::add
|
| 16 |
+
* will be included.
|
| 17 |
+
*
|
| 18 |
+
* Internally, this is done by removing the operator registration calls
|
| 19 |
+
* using compile time programming, and the linker will then prune all
|
| 20 |
+
* operator functions that weren't registered.
|
| 21 |
+
* See Note [Selective build] for more details
|
| 22 |
+
*
|
| 23 |
+
* WARNING: The allowlist mechanism doesn't work for all ways you could go about
|
| 24 |
+
* registering an operator. If the dispatch key / operator name is not
|
| 25 |
+
* sufficiently obvious at compile time, then the allowlisting mechanism
|
| 26 |
+
* will fail (and the operator will be included in the binary anyway).
|
| 27 |
+
*/
|
| 28 |
+
|
| 29 |
+
#include <string_view>
|
| 30 |
+
#include <c10/core/DispatchKey.h>
|
| 31 |
+
#include <c10/macros/Macros.h>
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
|
| 35 |
+
#include <ATen/record_function.h>
|
| 36 |
+
#endif
|
| 37 |
+
|
| 38 |
+
namespace c10::impl {
|
| 39 |
+
|
| 40 |
+
constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item); // Forward Declare
|
| 41 |
+
|
| 42 |
+
/**
|
| 43 |
+
* In selective build mode returns true/false depending on whether a build
|
| 44 |
+
* feature is available or not.
|
| 45 |
+
*
|
| 46 |
+
* In instrumenting mode (tracing mode), always returns true, and doesn't
|
| 47 |
+
* trigger any side effects.
|
| 48 |
+
*/
|
| 49 |
+
constexpr bool is_build_feature_available(const char* name) {
#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) || !defined(TORCH_BUILD_FEATURE_ALLOWLIST)
  // Either instrumenting mode, or selective build mode without a feature
  // allowlist: every feature counts as available.
  (void)name;
  return true;
#else
  // Selective build mode with an allowlist: look the feature up in it.
  return allowlist_contains(
      C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST),
      name);
#endif
}
|
| 67 |
+
|
| 68 |
+
[[noreturn]] void build_feature_required_feature_not_available(const char* feature);
|
| 69 |
+
|
| 70 |
+
/**
|
| 71 |
+
* Use BUILD_FEATURE_REQUIRED macro in user-code.
|
| 72 |
+
*
|
| 73 |
+
* In selective build mode becomes a no-op if the build feature passed
|
| 74 |
+
* in is available. If not available, throws an exception (c10::Error).
|
| 75 |
+
* The compiler is able to perform dead code elimination for code
|
| 76 |
+
* following this method if the build feature is not available.
|
| 77 |
+
*
|
| 78 |
+
* In instrumenting mode (tracing mode), registers (as a side effect)
|
| 79 |
+
* the presence of this specific build feature being triggered.
|
| 80 |
+
*/
|
| 81 |
+
#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
// trace mode
#define BUILD_FEATURE_REQUIRED(NAME)  \
  RECORD_FUNCTION_WITH_SCOPE(         \
      at::RecordScope::BUILD_FEATURE, \
      std::string(NAME),              \
      {});
#elif defined(TORCH_BUILD_FEATURE_ALLOWLIST)
// selective build mode with a feature allowlist
#define BUILD_FEATURE_REQUIRED(NAME)                                 \
  if (!c10::impl::is_build_feature_available(NAME)) {                \
    ::c10::impl::build_feature_required_feature_not_available(NAME); \
  }
#else
// selective build mode without an allowlist: everything trivially selected
#define BUILD_FEATURE_REQUIRED(NAME)
#endif
|
| 100 |
+
|
| 101 |
+
// Use this macro, and not is_build_feature_available
|
| 102 |
+
#define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME)
|
| 103 |
+
|
| 104 |
+
// returns true iff allowlist contains item
|
| 105 |
+
// allowlist_contains("a;bc;d", "bc") == true
|
| 106 |
+
constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item) {
  // Walk the ';'-separated entries from left to right. `entry_begin` is the
  // offset of the entry under inspection; it never exceeds allowlist.size().
  size_t entry_begin = 0;
  while (entry_begin <= allowlist.size()) {
    const size_t separator = allowlist.find(';', entry_begin);
    if (separator == std::string_view::npos) {
      // Last entry: it runs to the end of the allowlist.
      return allowlist.substr(entry_begin) == item;
    }
    if (allowlist.substr(entry_begin, separator - entry_begin) == item) {
      return true;
    }
    entry_begin = separator + 1;
  }
  return false;
}
|
| 126 |
+
|
| 127 |
+
// Returns true iff the given op name is on the allowlist
|
| 128 |
+
// and should be registered
|
| 129 |
+
constexpr bool op_allowlist_check(std::string_view op_name [[maybe_unused]]) {
  // Sanity checks on the name: it must contain "::" and must not contain '('.
  // Use assert() instead of throw() due to a gcc bug. See:
  // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function
  // https://github.com/fmtlib/fmt/issues/682
  assert(op_name.find("::") != std::string_view::npos);
  assert(op_name.find('(') == std::string_view::npos);
#if defined(TORCH_OPERATOR_WHITELIST)
  // The name is looked up as given, without stripping an overload name:
  // this function is mainly used for mobile selective build with root
  // operators, where the overload is included in the allowlist.
  return allowlist_contains(
      C10_STRINGIZE(TORCH_OPERATOR_WHITELIST),
      op_name);
#else
  // If the TORCH_OPERATOR_WHITELIST parameter is not defined,
  // all ops are to be registered
  return true;
#endif
}
|
| 151 |
+
|
| 152 |
+
// Returns true iff the given schema string is on the allowlist
|
| 153 |
+
// and should be registered
|
| 154 |
+
constexpr bool schema_allowlist_check(std::string_view schema) {
|
| 155 |
+
#if defined(TORCH_FORCE_SCHEMA_REGISTRATION)
|
| 156 |
+
return true;
|
| 157 |
+
#else
|
| 158 |
+
return op_allowlist_check(schema.substr(0, schema.find('(')));
|
| 159 |
+
#endif
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Returns true iff the given custom class name is on the allowlist
|
| 163 |
+
// and should be registered
|
| 164 |
+
constexpr bool custom_class_allowlist_check(std::string_view custom_class_name [[maybe_unused]]) {
#if defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
  // An allowlist was supplied: the class must be listed in it.
  return allowlist_contains(
      C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST),
      custom_class_name);
#else
  // If the TORCH_CUSTOM_CLASS_ALLOWLIST parameter is not defined,
  // all custom classes are to be registered
  return true;
#endif
}
|
| 175 |
+
|
| 176 |
+
// schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
|
| 177 |
+
// Add this API to pass arbitrary allowlist.
|
| 178 |
+
constexpr bool op_allowlist_contains_name_in_schema(std::string_view allowlist, std::string_view schema) {
|
| 179 |
+
return allowlist_contains(allowlist, schema.substr(0, schema.find('(')));
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
} // namespace c10::impl
|
| 183 |
+
|
| 184 |
+
#else
|
| 185 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 186 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_registration.h
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
/**
|
| 5 |
+
* Include this file if you want to register operators. It includes all
|
| 6 |
+
* functionality needed to do so for you.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#include <c10/core/DispatchKey.h>
|
| 10 |
+
#include <c10/core/DispatchKeySet.h>
|
| 11 |
+
#include <c10/core/CompileTimeFunctionPointer.h>
|
| 12 |
+
#include <ATen/core/boxing/KernelFunction.h>
|
| 13 |
+
#include <ATen/core/dispatch/CppSignature.h>
|
| 14 |
+
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
|
| 15 |
+
#include <ATen/core/op_registration/infer_schema.h>
|
| 16 |
+
#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
|
| 17 |
+
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
| 18 |
+
#endif
|
| 19 |
+
#include <ATen/core/ATenOpList.h>
|
| 20 |
+
|
| 21 |
+
namespace c10 {
|
| 22 |
+
|
| 23 |
+
namespace detail {
|
| 24 |
+
// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
|
| 25 |
+
// We do this because every argument in a function schema is expected to be convertible
|
| 26 |
+
// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
|
| 27 |
+
// See Note [Plumbing Keys Through The Dispatcher]
|
| 28 |
+
template<class KernelFunctor>
|
| 29 |
+
std::unique_ptr<FunctionSchema> inferFunctionSchemaFromFunctor() {
|
| 30 |
+
using func_type = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::func_type;
|
| 31 |
+
return std::make_unique<FunctionSchema>(inferFunctionSchemaFlattenedReturns<func_type>());
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* An instance of this class handles the registration for one or more operators.
|
| 37 |
+
* Make sure you keep the RegisterOperators instance around since it will
|
| 38 |
+
* deregister the operator it's responsible for in its destructor.
|
| 39 |
+
*
|
| 40 |
+
* Example:
|
| 41 |
+
*
|
| 42 |
+
* > namespace {
|
| 43 |
+
* > class my_kernel_cpu final : public c10::OperatorKernel {
|
| 44 |
+
* > public:
|
| 45 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 46 |
+
* > };
|
| 47 |
+
* > }
|
| 48 |
+
* >
|
| 49 |
+
* > static auto registry = c10::RegisterOperators()
|
| 50 |
+
* > .op(c10::RegisterOperators::options()
|
| 51 |
+
* > .schema("my_op")
|
| 52 |
+
* > .kernel<my_kernel_cpu>(DispatchKey::CPU));
|
| 53 |
+
*/
|
| 54 |
+
class TORCH_API RegisterOperators final {
|
| 55 |
+
public:
|
| 56 |
+
RegisterOperators() = default;
|
| 57 |
+
~RegisterOperators() = default;
|
| 58 |
+
|
| 59 |
+
RegisterOperators(const RegisterOperators&) = delete;
|
| 60 |
+
RegisterOperators& operator=(const RegisterOperators&) = delete;
|
| 61 |
+
RegisterOperators(RegisterOperators&&) noexcept = default;
|
| 62 |
+
RegisterOperators& operator=(RegisterOperators&&) noexcept = default;
|
| 63 |
+
|
| 64 |
+
class TORCH_API Options final {
|
| 65 |
+
public:
|
| 66 |
+
Options(const Options&) = delete;
|
| 67 |
+
Options(Options&&) noexcept = delete;
|
| 68 |
+
Options& operator=(const Options&) = delete;
|
| 69 |
+
Options& operator=(Options&&) noexcept = delete;
|
| 70 |
+
|
| 71 |
+
// internal-only for registering stack based kernels
|
| 72 |
+
template<KernelFunction::BoxedKernelFunction* kernel_func>
|
| 73 |
+
Options&& kernel(DispatchKey dispatch_key) && {
|
| 74 |
+
return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// internal-only for registering stack based catch-all kernels
|
| 78 |
+
template<KernelFunction::BoxedKernelFunction* kernel_func>
|
| 79 |
+
Options&& catchAllKernel() && {
|
| 80 |
+
return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// internal only for registering caffe2 ops
|
| 84 |
+
Options&& schema(FunctionSchema&& schema) {
|
| 85 |
+
TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration.");
|
| 86 |
+
schemaOrName_ = FunctionSchema(std::move(schema));
|
| 87 |
+
return std::move(*this);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
/**
|
| 91 |
+
* Use this to specify the schema for an operator. You can also specify
|
| 92 |
+
* the operator name only to have the function signature part of the
|
| 93 |
+
* schema be inferred from the kernel function.
|
| 94 |
+
*
|
| 95 |
+
* Example:
|
| 96 |
+
*
|
| 97 |
+
* > // Infer function signature from my_kernel_cpu
|
| 98 |
+
* > static auto registry = c10::RegisterOperators()
|
| 99 |
+
* > .op(c10::RegisterOperators::options()
|
| 100 |
+
* > .schema("my_op")
|
| 101 |
+
* > .kernel<my_kernel_cpu>(DispatchKey::CPU));
|
| 102 |
+
* >
|
| 103 |
+
* >
|
| 104 |
+
* > // Explicitly specify full schema
|
| 105 |
+
* > static auto registry = c10::RegisterOperators()
|
| 106 |
+
* > .op(c10::RegisterOperators::options()
|
| 107 |
+
* > .schema("my_op(Tensor a) -> Tensor")
|
| 108 |
+
* > .kernel<my_kernel_cpu>(DispatchKey::CPU));
|
| 109 |
+
*/
|
| 110 |
+
Options&& schema(const std::string& schemaOrName) {
|
| 111 |
+
TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration.");
|
| 112 |
+
|
| 113 |
+
#if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD)
|
| 114 |
+
throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
|
| 115 |
+
#else
|
| 116 |
+
schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName);
|
| 117 |
+
#endif
|
| 118 |
+
|
| 119 |
+
return std::move(*this);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
/**
|
| 123 |
+
* Use this to register an operator whose kernel is implemented as a functor.
|
| 124 |
+
* The kernel is only called for inputs matching the given dispatch key.
|
| 125 |
+
* You can register multiple kernels for different dispatch keys.
|
| 126 |
+
*
|
| 127 |
+
* Example:
|
| 128 |
+
*
|
| 129 |
+
* > namespace {
|
| 130 |
+
* > class my_kernel_cpu final : public c10::OperatorKernel {
|
| 131 |
+
* > public:
|
| 132 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 133 |
+
* > };
|
| 134 |
+
* > }
|
| 135 |
+
* >
|
| 136 |
+
* > static auto registry = c10::RegisterOperators()
|
| 137 |
+
* > .op(c10::RegisterOperators::options()
|
| 138 |
+
* > .schema("my_op")
|
| 139 |
+
* > .kernel<my_kernel_cpu>(DispatchKey::CPU));
|
| 140 |
+
*
|
| 141 |
+
* The functor constructor can take arguments to configure the kernel.
|
| 142 |
+
* The arguments are defined in the kernel registration.
|
| 143 |
+
* Example:
|
| 144 |
+
*
|
| 145 |
+
* > namespace {
|
| 146 |
+
* > class my_kernel_cpu final : public c10::OperatorKernel {
|
| 147 |
+
* > public:
|
| 148 |
+
* > explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
|
| 149 |
+
* > : ... {...}
|
| 150 |
+
* >
|
| 151 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 152 |
+
* > };
|
| 153 |
+
* > }
|
| 154 |
+
* >
|
| 155 |
+
* > static auto registry = c10::RegisterOperators()
|
| 156 |
+
* > .op(c10::RegisterOperators::options()
|
| 157 |
+
* > .schema("my_op")
|
| 158 |
+
* > .kernel<my_kernel_cpu>(DispatchKey::CPU, "some_configuration", 3, true));
|
| 159 |
+
*/
|
| 160 |
+
template<class KernelFunctor, class... ConstructorParameters>
|
| 161 |
+
// enable_if: only enable it if KernelFunctor is actually a functor
|
| 162 |
+
std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && {
|
| 163 |
+
static_assert(std::is_base_of_v<OperatorKernel, KernelFunctor>, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
| 164 |
+
static_assert(std::is_constructible_v<KernelFunctor, ConstructorParameters...>, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
|
| 165 |
+
|
| 166 |
+
return std::move(*this).kernel(
|
| 167 |
+
dispatch_key,
|
| 168 |
+
KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
|
| 169 |
+
impl::CppSignature::make<KernelFunctor>(),
|
| 170 |
+
detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
|
| 171 |
+
);
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
/**
|
| 175 |
+
* Use this to register an operator whose kernel is implemented as a functor.
|
| 176 |
+
* The kernel is a catch-all kernel, meaning it's called independently of
|
| 177 |
+
* the input. Dispatch is disabled for this operator.
|
| 178 |
+
*
|
| 179 |
+
* Example:
|
| 180 |
+
*
|
| 181 |
+
* > namespace {
|
| 182 |
+
* > class my_kernel_cpu final : public c10::OperatorKernel {
|
| 183 |
+
* > public:
|
| 184 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 185 |
+
* > };
|
| 186 |
+
* > }
|
| 187 |
+
* >
|
| 188 |
+
* > static auto registry = c10::RegisterOperators()
|
| 189 |
+
* > .op(c10::RegisterOperators::options()
|
| 190 |
+
* > .schema("my_op")
|
| 191 |
+
* > .catchAllKernel<my_kernel_cpu>());
|
| 192 |
+
*
|
| 193 |
+
* The functor constructor can take arguments to configure the kernel.
|
| 194 |
+
* The arguments are defined in the kernel registration.
|
| 195 |
+
* Example:
|
| 196 |
+
*
|
| 197 |
+
* > namespace {
|
| 198 |
+
* > class my_kernel_cpu final : public c10::OperatorKernel {
|
| 199 |
+
* > public:
|
| 200 |
+
* > explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
|
| 201 |
+
* > : ... {...}
|
| 202 |
+
* >
|
| 203 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 204 |
+
* > };
|
| 205 |
+
* > }
|
| 206 |
+
* >
|
| 207 |
+
* > static auto registry = c10::RegisterOperators()
|
| 208 |
+
* > .op(c10::RegisterOperators::options()
|
| 209 |
+
* > .schema("my_op")
|
| 210 |
+
* > .catchAllKernel<my_kernel_cpu>("some_configuration", 3, true));
|
| 211 |
+
*/
|
| 212 |
+
template<class KernelFunctor, class... ConstructorParameters>
|
| 213 |
+
// enable_if: only enable it if KernelFunctor is actually a functor
|
| 214 |
+
std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && {
|
| 215 |
+
static_assert(std::is_base_of_v<OperatorKernel, KernelFunctor>, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
|
| 216 |
+
static_assert(std::is_constructible_v<KernelFunctor, ConstructorParameters...>, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
|
| 217 |
+
|
| 218 |
+
return std::move(*this).kernel(
|
| 219 |
+
std::nullopt,
|
| 220 |
+
KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
|
| 221 |
+
impl::CppSignature::make<KernelFunctor>(),
|
| 222 |
+
detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
|
| 223 |
+
);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/**
|
| 227 |
+
* Use this to register an operator whose kernel is implemented by a function.
|
| 228 |
+
* The kernel is only called for inputs matching the given dispatch key.
|
| 229 |
+
* You can register multiple kernels for different dispatch keys.
|
| 230 |
+
*
|
| 231 |
+
* Example:
|
| 232 |
+
*
|
| 233 |
+
* > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
|
| 234 |
+
* >
|
| 235 |
+
* > static auto registry = c10::RegisterOperators()
|
| 236 |
+
* > .op(c10::RegisterOperators::options()
|
| 237 |
+
* > .schema("my_op")
|
| 238 |
+
* > .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(DispatchKey::CPU));
|
| 239 |
+
*/
|
| 240 |
+
template<class FuncType, FuncType* kernel_func>
|
| 241 |
+
// enable_if: only enable it if FuncType is actually a function
|
| 242 |
+
std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key) && {
|
| 243 |
+
static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
|
| 244 |
+
static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
|
| 245 |
+
|
| 246 |
+
return std::move(*this).kernel(
|
| 247 |
+
dispatch_key,
|
| 248 |
+
KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
|
| 249 |
+
impl::CppSignature::make<FuncType>(),
|
| 250 |
+
// TODO Do schema inference without relying on WrapFunctionIntoFunctor
|
| 251 |
+
detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
|
| 252 |
+
);
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
/**
|
| 256 |
+
* Use this to register an operator whose kernel is implemented by a function.
|
| 257 |
+
* The kernel is a catch-all kernel, meaning it's called independently of
|
| 258 |
+
* the input. Dispatch is disabled for this operator.
|
| 259 |
+
*
|
| 260 |
+
* Example:
|
| 261 |
+
*
|
| 262 |
+
* > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
|
| 263 |
+
* >
|
| 264 |
+
* > static auto registry = c10::RegisterOperators()
|
| 265 |
+
* > .op(c10::RegisterOperators::options()
|
| 266 |
+
* > .schema("my_op")
|
| 267 |
+
* > .catchAllKernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
|
| 268 |
+
*/
|
| 269 |
+
template<class FuncType, FuncType* kernel_func>
|
| 270 |
+
// enable_if: only enable it if FuncType is actually a function
|
| 271 |
+
std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel() && {
|
| 272 |
+
static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
|
| 273 |
+
static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
|
| 274 |
+
|
| 275 |
+
return std::move(*this).kernel(
|
| 276 |
+
std::nullopt,
|
| 277 |
+
KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
|
| 278 |
+
impl::CppSignature::make<FuncType>(),
|
| 279 |
+
// TODO Do schema inference without relying on WrapFunctionIntoFunctor
|
| 280 |
+
detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
|
| 281 |
+
);
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
template<class FuncType>
|
| 285 |
+
// enable_if: only enable it if FuncType is actually a function
|
| 286 |
+
std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && {
|
| 287 |
+
static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
|
| 288 |
+
TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
|
| 289 |
+
|
| 290 |
+
return std::move(*this).kernel(
|
| 291 |
+
dispatch_key,
|
| 292 |
+
KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
|
| 293 |
+
impl::CppSignature::make<FuncType>(),
|
| 294 |
+
// TODO Do schema inference without relying on WrapFunctionIntoFunctor
|
| 295 |
+
detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
|
| 296 |
+
);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
template<class FuncType>
|
| 300 |
+
// enable_if: only enable it if FuncType is actually a function
|
| 301 |
+
std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel(FuncType* kernel_func) && {
|
| 302 |
+
static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
|
| 303 |
+
TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
|
| 304 |
+
|
| 305 |
+
return std::move(*this).kernel(
|
| 306 |
+
std::nullopt,
|
| 307 |
+
KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
|
| 308 |
+
impl::CppSignature::make<FuncType>(),
|
| 309 |
+
// TODO Do schema inference without relying on WrapFunctionIntoFunctor
|
| 310 |
+
detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
|
| 311 |
+
);
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
/**
|
| 315 |
+
* Use this to register an operator whose kernel is implemented as a lambda.
|
| 316 |
+
* The kernel is only called for inputs matching the given dispatch key.
|
| 317 |
+
* You can register multiple kernels for different dispatch keys.
|
| 318 |
+
*
|
| 319 |
+
* The lambda must be stateless, i.e. not have a capture. If your kernel
|
| 320 |
+
* needs to store some configuration parameters, write the kernel as a
|
| 321 |
+
* functor instead.
|
| 322 |
+
*
|
| 323 |
+
* Example:
|
| 324 |
+
*
|
| 325 |
+
* > static auto registry = c10::RegisterOperators()
|
| 326 |
+
* > .op(c10::RegisterOperators::options()
|
| 327 |
+
* > .schema("my_op")
|
| 328 |
+
* > .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...}));
|
| 329 |
+
*/
|
| 330 |
+
template<class Lambda>
|
| 331 |
+
// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
|
| 332 |
+
std::enable_if_t<
|
| 333 |
+
guts::is_functor<std::decay_t<Lambda>>::value
|
| 334 |
+
&& !std::is_same_v<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>,
|
| 335 |
+
Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && {
|
| 336 |
+
static_assert(!std::is_base_of_v<OperatorKernel, std::decay_t<Lambda>>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
|
| 337 |
+
|
| 338 |
+
// We don't support stateful lambdas (i.e. lambdas with a capture), because their
|
| 339 |
+
// behavior would be nonobvious. A functor kernel with cache gets a new instance of
|
| 340 |
+
// its cache each time the kernel is looked up from the dispatch table.
|
| 341 |
+
// A lambda with a capture would be global and share its capture between all kernel lookups.
|
| 342 |
+
// So, instead of making users think about it (including the thread-safety
|
| 343 |
+
// issues this causes), let's just forbid stateful lambdas altogether.
|
| 344 |
+
static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
|
| 345 |
+
|
| 346 |
+
return std::move(*this).kernel(
|
| 347 |
+
dispatch_key,
|
| 348 |
+
KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(functor)),
|
| 349 |
+
impl::CppSignature::make<Lambda>(),
|
| 350 |
+
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
|
| 351 |
+
detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
|
| 352 |
+
);
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
/**
|
| 356 |
+
* Use this to register an operator whose kernel is implemented as a lambda.
|
| 357 |
+
* The kernel is a catch-all kernel, meaning it's called independently of
|
| 358 |
+
* the input. Dispatch is disabled for this operator.
|
| 359 |
+
*
|
| 360 |
+
* The lambda must be stateless, i.e. not have a capture. If your kernel
|
| 361 |
+
* needs to store some configuration parameters, write the kernel as a
|
| 362 |
+
* functor instead.
|
| 363 |
+
*
|
| 364 |
+
* Example:
|
| 365 |
+
*
|
| 366 |
+
* > static auto registry = c10::RegisterOperators()
|
| 367 |
+
* > .op(c10::RegisterOperators::options()
|
| 368 |
+
* > .schema("my_op")
|
| 369 |
+
* > .catchAllKernel([] (Tensor a) -> Tensor {...}));
|
| 370 |
+
*/
|
| 371 |
+
template<class Lambda>
|
| 372 |
+
// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
|
| 373 |
+
std::enable_if_t<
|
| 374 |
+
guts::is_functor<std::decay_t<Lambda>>::value
|
| 375 |
+
&& !std::is_same_v<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>,
|
| 376 |
+
Options&&> catchAllKernel(Lambda&& lambda) && {
|
| 377 |
+
static_assert(!std::is_base_of_v<OperatorKernel, std::decay_t<Lambda>>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
|
| 378 |
+
|
| 379 |
+
// We don't support stateful lambdas (i.e. lambdas with a capture), because their
|
| 380 |
+
// behavior would be nonobvious.
|
| 381 |
+
// A lambda with a capture would be global and share its capture between all kernel lookups.
|
| 382 |
+
// This would be a likely source for unexpected race conditions, so we forbid it.
|
| 383 |
+
// If a kernel really needs global state, they can just have regular global state
|
| 384 |
+
// in their .cpp file next to the kernel lambda.
|
| 385 |
+
static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
|
| 386 |
+
|
| 387 |
+
return std::move(*this).kernel(
|
| 388 |
+
std::nullopt,
|
| 389 |
+
KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(lambda)),
|
| 390 |
+
impl::CppSignature::make<Lambda>(),
|
| 391 |
+
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
|
| 392 |
+
detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
|
| 393 |
+
);
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && {
|
| 397 |
+
TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration.");
|
| 398 |
+
aliasAnalysisKind_ = aliasAnalysisKind;
|
| 399 |
+
return std::move(*this);
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
private:
|
| 403 |
+
Options&& kernel(std::optional<DispatchKey> dispatch_key, KernelFunction&& func, std::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
|
| 404 |
+
KernelRegistrationConfig config;
|
| 405 |
+
config.dispatch_key = dispatch_key;
|
| 406 |
+
config.func = std::move(func);
|
| 407 |
+
config.cpp_signature = cpp_signature;
|
| 408 |
+
config.inferred_function_schema = std::move(inferred_function_schema);
|
| 409 |
+
kernels.push_back(std::move(config));
|
| 410 |
+
return std::move(*this);
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
Options()
|
| 414 |
+
: schemaOrName_(std::nullopt)
|
| 415 |
+
, aliasAnalysisKind_(std::nullopt)
|
| 416 |
+
{}
|
| 417 |
+
|
| 418 |
+
// KernelRegistrationConfig accumulates all information from the config
|
| 419 |
+
// parameters passed to a RegisterOperators::op() call into one object.
|
| 420 |
+
struct KernelRegistrationConfig final {
|
| 421 |
+
KernelRegistrationConfig()
|
| 422 |
+
: dispatch_key(std::nullopt)
|
| 423 |
+
, cpp_signature(std::nullopt)
|
| 424 |
+
, inferred_function_schema(nullptr)
|
| 425 |
+
{}
|
| 426 |
+
|
| 427 |
+
std::optional<DispatchKey> dispatch_key;
|
| 428 |
+
KernelFunction func;
|
| 429 |
+
std::optional<impl::CppSignature> cpp_signature;
|
| 430 |
+
std::unique_ptr<FunctionSchema> inferred_function_schema;
|
| 431 |
+
};
|
| 432 |
+
|
| 433 |
+
std::optional<std::variant<OperatorName, FunctionSchema>> schemaOrName_;
|
| 434 |
+
|
| 435 |
+
std::vector<KernelRegistrationConfig> kernels;
|
| 436 |
+
std::optional<AliasAnalysisKind> aliasAnalysisKind_;
|
| 437 |
+
friend class RegisterOperators;
|
| 438 |
+
friend class Library;
|
| 439 |
+
};
|
| 440 |
+
|
| 441 |
+
/**
|
| 442 |
+
* Call this to get an instance of registration options, which
|
| 443 |
+
* can be passed to a call to RegisterOperators::op() to specify
|
| 444 |
+
* these options for the operator registration.
|
| 445 |
+
* See class doc comment for examples.
|
| 446 |
+
*/
|
| 447 |
+
static Options options() {
|
| 448 |
+
return {};
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
/**
|
| 452 |
+
* Call this to register an operator. See class doc comment for examples.
|
| 453 |
+
*/
|
| 454 |
+
RegisterOperators&& op(Options&& options) && {
|
| 455 |
+
checkSchemaAndRegisterOp_(std::move(options));
|
| 456 |
+
return std::move(*this);
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
// Regular mutator version of the && version above
|
| 460 |
+
RegisterOperators& op(Options&& options) & {
|
| 461 |
+
checkSchemaAndRegisterOp_(std::move(options));
|
| 462 |
+
return *this;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
/**
|
| 466 |
+
* This is a shorthand for RegisterOperators::op(Options) where you can
|
| 467 |
+
* specify the operator schema outside of the options parameter.
|
| 468 |
+
* See class doc comment for examples.
|
| 469 |
+
*/
|
| 470 |
+
RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && {
|
| 471 |
+
return std::move(*this).op(std::move(options).schema(schemaOrName));
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
// internal only for registering caffe2 ops
|
| 475 |
+
RegisterOperators&& op(FunctionSchema schema, Options&& options) && {
|
| 476 |
+
return std::move(*this).op(std::move(options).schema(std::move(schema)));
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
template<class FuncType>
|
| 480 |
+
explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options())
|
| 481 |
+
: RegisterOperators() {
|
| 482 |
+
std::move(*this).op(schemaOrName, std::forward<FuncType>(func), std::move(options));
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
/**
|
| 486 |
+
* This API registers an operator based on a kernel function pointer.
|
| 487 |
+
*
|
| 488 |
+
* Given a kernel
|
| 489 |
+
*
|
| 490 |
+
* > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
|
| 491 |
+
*
|
| 492 |
+
* This API looks like:
|
| 493 |
+
*
|
| 494 |
+
* > static auto registry = c10::RegisterOperators()
|
| 495 |
+
* > .op("my_op", &my_kernel_cpu);
|
| 496 |
+
*
|
| 497 |
+
* If your kernel is small and the overhead of calling it matters,
|
| 498 |
+
* then this API might be the wrong choice since the following API
|
| 499 |
+
* has a slightly lower overhead for calling into the kernel:
|
| 500 |
+
*
|
| 501 |
+
* > static auto registry = c10::RegisterOperators()
|
| 502 |
+
* > .op("my_op", c10::RegisterOperators::options()
|
| 503 |
+
* > .catchAllKernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
|
| 504 |
+
*
|
| 505 |
+
* Or, alternatively, write your kernel as a functor:
|
| 506 |
+
*
|
| 507 |
+
* > namespace {
|
| 508 |
+
* > class my_kernel_cpu final : public c10::OperatorKernel {
|
| 509 |
+
* > public:
|
| 510 |
+
* > Tensor operator()(Tensor a, Tensor b) {...}
|
| 511 |
+
* > };
|
| 512 |
+
* > }
|
| 513 |
+
* >
|
| 514 |
+
* > static auto registry = c10::RegisterOperators()
|
| 515 |
+
* > .op("my_op", c10::RegisterOperators::options()
|
| 516 |
+
* > .catchAllKernel<my_kernel_cpu>());
|
| 517 |
+
*/
|
| 518 |
+
template<class FuncType>
|
| 519 |
+
// enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction.
|
| 520 |
+
std::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, RegisterOperators&&>
|
| 521 |
+
op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && {
|
| 522 |
+
constexpr bool AllowLegacyTypes = true;
|
| 523 |
+
return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
|
| 524 |
+
std::nullopt,
|
| 525 |
+
KernelFunction::makeFromUnboxedRuntimeFunction<AllowLegacyTypes>(func),
|
| 526 |
+
impl::CppSignature::make<FuncType>(),
|
| 527 |
+
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
|
| 528 |
+
detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
|
| 529 |
+
));
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
/**
|
| 533 |
+
* This API registers an operator based on a kernel lambda.
|
| 534 |
+
*
|
| 535 |
+
* This API looks like:
|
| 536 |
+
*
|
| 537 |
+
* > static auto registry = c10::RegisterOperators()
|
| 538 |
+
* > .op("my_op", [] (Tensor a, Tensor b) {...});
|
| 539 |
+
*
|
| 540 |
+
* This is equivalent to:
|
| 541 |
+
*
|
| 542 |
+
* > static auto registry = c10::RegisterOperators()
|
| 543 |
+
* > .op("my_op", c10::RegisterOperators::options()
|
| 544 |
+
* > .catchAllKernel([] (Tensor a, Tensor b) {...}));
|
| 545 |
+
*
|
| 546 |
+
*/
|
| 547 |
+
template<class Lambda>
|
| 548 |
+
// enable_if: only enable it if Lambda is actually a stateless lambda
|
| 549 |
+
std::enable_if_t<guts::is_functor<Lambda>::value && guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
|
| 550 |
+
op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
|
| 551 |
+
static_assert(!std::is_base_of_v<OperatorKernel, Lambda>, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
|
| 552 |
+
|
| 553 |
+
constexpr bool AllowLegacyTypes = true;
|
| 554 |
+
return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
|
| 555 |
+
std::nullopt,
|
| 556 |
+
KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
|
| 557 |
+
impl::CppSignature::make<Lambda>(),
|
| 558 |
+
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
|
| 559 |
+
detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
|
| 560 |
+
));
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
template<class Lambda>
|
| 564 |
+
C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead.")
|
| 565 |
+
// enable_if: only enable it if Lambda is actually a functor but not a stateless lambda
|
| 566 |
+
std::enable_if_t<guts::is_functor<Lambda>::value && !guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
|
| 567 |
+
op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
|
| 568 |
+
static_assert(!std::is_base_of_v<OperatorKernel, Lambda>, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
|
| 569 |
+
|
| 570 |
+
constexpr bool AllowLegacyTypes = true;
|
| 571 |
+
return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
|
| 572 |
+
std::nullopt,
|
| 573 |
+
KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
|
| 574 |
+
impl::CppSignature::make<Lambda>(),
|
| 575 |
+
// TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
|
| 576 |
+
detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
|
| 577 |
+
));
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
private:
|
| 581 |
+
void checkSchemaAndRegisterOp_(Options&& config);
|
| 582 |
+
|
| 583 |
+
static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options);
|
| 584 |
+
void checkNoDuplicateKernels_(const Options& options);
|
| 585 |
+
void registerOp_(Options&& options);
|
| 586 |
+
|
| 587 |
+
std::vector<RegistrationHandleRAII> registrars_;
|
| 588 |
+
};
|
| 589 |
+
|
| 590 |
+
} // namespace c10
|
| 591 |
+
|
| 592 |
+
namespace torch {
|
| 593 |
+
// Old-style API
|
| 594 |
+
using RegisterOperators = c10::RegisterOperators;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
#else
|
| 598 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 599 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/FlushDenormal.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/// Flush-To-Zero and Denormals-Are-Zero mode
|
| 3 |
+
///
|
| 4 |
+
/// Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) are modes that bypass
|
| 5 |
+
/// IEEE 754 methods of dealing with denormal floating-point numbers on x86-64
|
| 6 |
+
/// and some x86 CPUs. They result in reduced precision for values near zero,
|
| 7 |
+
/// but increased performance.
|
| 8 |
+
///
|
| 9 |
+
/// See https://software.intel.com/en-us/articles/x87-and-sse-floating-point-assists-in-ia-32-flush-to-zero-ftz-and-denormals-are-zero-daz
|
| 10 |
+
|
| 11 |
+
namespace at::cpu {
|
| 12 |
+
|
| 13 |
+
bool set_flush_denormal(bool on);
|
| 14 |
+
|
| 15 |
+
} // namespace at::cpu
|
| 16 |
+
|
| 17 |
+
#else
|
| 18 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 19 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/Utils.h
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
#include <c10/macros/Export.h>
|
| 7 |
+
|
| 8 |
+
namespace at::cpu {
|
| 9 |
+
|
| 10 |
+
TORCH_API bool is_avx2_supported();
|
| 11 |
+
TORCH_API bool is_avx512_supported();
|
| 12 |
+
|
| 13 |
+
// Detect if CPU supports Vector Neural Network Instructions.
|
| 14 |
+
TORCH_API bool is_avx512_vnni_supported();
|
| 15 |
+
|
| 16 |
+
// Detect if CPU supports AVX512_BF16 ISA
|
| 17 |
+
TORCH_API bool is_avx512_bf16_supported();
|
| 18 |
+
|
| 19 |
+
// Detect if CPU supports Advanced Matrix Extension.
|
| 20 |
+
TORCH_API bool is_amx_tile_supported();
|
| 21 |
+
|
| 22 |
+
// Detect if CPU supports Advanced Matrix Extension for fp16.
|
| 23 |
+
TORCH_API bool is_amx_fp16_supported();
|
| 24 |
+
|
| 25 |
+
// Enable the system to use AMX instructions.
|
| 26 |
+
TORCH_API bool init_amx();
|
| 27 |
+
|
| 28 |
+
// Get the L1 cache size per core in bytes
|
| 29 |
+
TORCH_API uint32_t L1d_cache_size();
|
| 30 |
+
|
| 31 |
+
// Get the L2 cache size per core in bytes
|
| 32 |
+
TORCH_API uint32_t L2_cache_size();
|
| 33 |
+
|
| 34 |
+
} // namespace at::cpu
|
| 35 |
+
|
| 36 |
+
#else
|
| 37 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 38 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vml.h
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/Config.h>
|
| 5 |
+
#include <ATen/Parallel.h>
|
| 6 |
+
#include <ATen/OpMathType.h>
|
| 7 |
+
#include <ATen/cpu/vec/functional.h>
|
| 8 |
+
#include <ATen/cpu/vec/vec.h>
|
| 9 |
+
#include <c10/util/complex.h>
|
| 10 |
+
|
| 11 |
+
// This header implements various unary operations using a MKL VML style
|
| 12 |
+
// interface.
|
| 13 |
+
|
| 14 |
+
// It implements various functions with a simple interface
|
| 15 |
+
// For example it enables the user to call vsin(float* out, const float* in,
|
| 16 |
+
// size). This function takes a pointer to a contiguous output array of floats and
|
| 17 |
+
// a constant input array. It will then apply sin to each value in the input
|
| 18 |
+
// array and write the result into the output array. out and in may point to the
|
| 19 |
+
// same memory, i.e. this fully supports in-place operations. These functions
|
| 20 |
+
// also implement their own parallelization, so take precautions when calling
|
| 21 |
+
// these from threaded functions.
|
| 22 |
+
|
| 23 |
+
// When MKL is available it will call into MKL's VML library similar to NumPy
|
| 24 |
+
// If MKL is not available it will use SLEEF.
|
| 25 |
+
|
| 26 |
+
// This file might be compiled under AVX or AVX2 when called from e.g.
|
| 27 |
+
// UnaryOpsKernel.cpp
|
| 28 |
+
|
| 29 |
+
#include <algorithm>
|
| 30 |
+
#include <cstddef>
|
| 31 |
+
#include <cstdint>
|
| 32 |
+
#include <cstring>
|
| 33 |
+
#include <type_traits>
|
| 34 |
+
|
| 35 |
+
#if AT_MKL_ENABLED() && !defined(__APPLE__)
|
| 36 |
+
#include <mkl.h>
|
| 37 |
+
#endif
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
namespace at::vml {
|
| 41 |
+
inline namespace CPU_CAPABILITY {
|
| 42 |
+
|
| 43 |
+
using namespace vec;
|
| 44 |
+
|
| 45 |
+
template <typename scalar_t>
|
| 46 |
+
inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
|
| 47 |
+
parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
|
| 48 |
+
map(
|
| 49 |
+
[](const Vectorized<scalar_t>& x) {
|
| 50 |
+
return Vectorized<scalar_t>((scalar_t)1) / x.sqrt();
|
| 51 |
+
},
|
| 52 |
+
out + begin,
|
| 53 |
+
in + begin,
|
| 54 |
+
end - begin);
|
| 55 |
+
});
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// NB: We ignore numerical errors by convention and leave them to the user
|
| 59 |
+
|
| 60 |
+
#define IMPLEMENT_VML(op) \
|
| 61 |
+
template <typename scalar_t> \
|
| 62 |
+
inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
|
| 63 |
+
using vec_t = Vectorized<vec_scalar_t<scalar_t>>; \
|
| 64 |
+
vec::map([](vec_t x) { return x.op(); }, out, in, size); \
|
| 65 |
+
} \
|
| 66 |
+
|
| 67 |
+
IMPLEMENT_VML(abs)
|
| 68 |
+
IMPLEMENT_VML(acos)
|
| 69 |
+
IMPLEMENT_VML(asin)
|
| 70 |
+
IMPLEMENT_VML(atan)
|
| 71 |
+
IMPLEMENT_VML(atanh)
|
| 72 |
+
IMPLEMENT_VML(ceil)
|
| 73 |
+
IMPLEMENT_VML(cos)
|
| 74 |
+
// IMPLEMENT_VML(cosh)
|
| 75 |
+
IMPLEMENT_VML(erf)
|
| 76 |
+
IMPLEMENT_VML(erfc)
|
| 77 |
+
IMPLEMENT_VML(erfinv)
|
| 78 |
+
IMPLEMENT_VML(exp)
|
| 79 |
+
IMPLEMENT_VML(expm1)
|
| 80 |
+
IMPLEMENT_VML(floor)
|
| 81 |
+
IMPLEMENT_VML(i0)
|
| 82 |
+
IMPLEMENT_VML(i0e)
|
| 83 |
+
IMPLEMENT_VML(digamma)
|
| 84 |
+
IMPLEMENT_VML(reciprocal)
|
| 85 |
+
IMPLEMENT_VML(log)
|
| 86 |
+
IMPLEMENT_VML(log10)
|
| 87 |
+
IMPLEMENT_VML(log1p)
|
| 88 |
+
IMPLEMENT_VML(log2)
|
| 89 |
+
IMPLEMENT_VML(neg)
|
| 90 |
+
IMPLEMENT_VML(sin)
|
| 91 |
+
// IMPLEMENT_VML(sinh)
|
| 92 |
+
IMPLEMENT_VML(sqrt)
|
| 93 |
+
IMPLEMENT_VML(round)
|
| 94 |
+
IMPLEMENT_VML(rsqrt)
|
| 95 |
+
IMPLEMENT_VML(tan)
|
| 96 |
+
IMPLEMENT_VML(tanh)
|
| 97 |
+
IMPLEMENT_VML(trunc)
|
| 98 |
+
IMPLEMENT_VML(lgamma)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
#if AT_MKL_ENABLED() && !defined(__APPLE__)
|
| 102 |
+
|
| 103 |
+
// NB: LP64 MKL is the most commonly used and thus we assume it here. That means
|
| 104 |
+
// we need to expect MKL_INT to be of type int, which implies int32_t or int64_t in most
|
| 105 |
+
// cases.
|
| 106 |
+
static_assert(
|
| 107 |
+
std::is_same_v<MKL_INT, int32_t> || std::is_same_v<MKL_INT, int64_t>,
|
| 108 |
+
"MKL_INT is assumed to be int32_t or int64_t");
|
| 109 |
+
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \
|
| 110 |
+
template <> \
|
| 111 |
+
inline void v##op(type * out, const type * in, int64_t size) { \
|
| 112 |
+
auto constexpr max_mkl_ind = std::numeric_limits<MKL_INT>::max(); \
|
| 113 |
+
if (size <= static_cast<int64_t>(max_mkl_ind)) { \
|
| 114 |
+
vm##mkltype##mklop( \
|
| 115 |
+
size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
| 116 |
+
} else { \
|
| 117 |
+
int64_t ind = 0; \
|
| 118 |
+
int64_t chunks = size / max_mkl_ind; \
|
| 119 |
+
int64_t rest = size % max_mkl_ind; \
|
| 120 |
+
for (; ind < chunks; ind++) { \
|
| 121 |
+
vm##mkltype##mklop( \
|
| 122 |
+
max_mkl_ind, \
|
| 123 |
+
in + ind * max_mkl_ind, \
|
| 124 |
+
out + ind * max_mkl_ind, \
|
| 125 |
+
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
| 126 |
+
} \
|
| 127 |
+
vm##mkltype##mklop( \
|
| 128 |
+
rest, \
|
| 129 |
+
in + ind * max_mkl_ind, \
|
| 130 |
+
out + ind * max_mkl_ind, \
|
| 131 |
+
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
|
| 132 |
+
} \
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
#define IMPLEMENT_VML_MKL(op, mklop) \
|
| 136 |
+
IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
|
| 137 |
+
IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
|
| 138 |
+
|
| 139 |
+
// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple
|
| 140 |
+
// NB: expm1 is disabled because on some configs it produces expm1(nan)=-1
|
| 141 |
+
IMPLEMENT_VML_MKL(acos, Acos)
|
| 142 |
+
IMPLEMENT_VML_MKL(asin, Asin)
|
| 143 |
+
IMPLEMENT_VML_MKL(atan, Atan)
|
| 144 |
+
IMPLEMENT_VML_MKL(cos, Cos)
|
| 145 |
+
// IMPLEMENT_VML_MKL(cosh, Cosh)
|
| 146 |
+
IMPLEMENT_VML_MKL(erf, Erf)
|
| 147 |
+
IMPLEMENT_VML_MKL(erfc, Erfc)
|
| 148 |
+
IMPLEMENT_VML_MKL(erfinv, ErfInv)
|
| 149 |
+
IMPLEMENT_VML_MKL(exp, Exp)
|
| 150 |
+
// IMPLEMENT_VML_MKL(expm1, Expm1)
|
| 151 |
+
IMPLEMENT_VML_MKL(log, Ln)
|
| 152 |
+
IMPLEMENT_VML_MKL(log10, Log10)
|
| 153 |
+
IMPLEMENT_VML_MKL(sin, Sin)
|
| 154 |
+
// IMPLEMENT_VML_MKL(sinh, Sinh)
|
| 155 |
+
IMPLEMENT_VML_MKL(sqrt, Sqrt)
|
| 156 |
+
IMPLEMENT_VML_MKL(tan, Tan)
|
| 157 |
+
IMPLEMENT_VML_MKL(tanh, Tanh)
|
| 158 |
+
IMPLEMENT_VML_MKL(trunc, Trunc)
|
| 159 |
+
|
| 160 |
+
// Not vectorized in MKL version tested
|
| 161 |
+
// IMPLEMENT_VML_MKL(abs, Abs)
|
| 162 |
+
// IMPLEMENT_VML_MKL(log1p, Log1p)
|
| 163 |
+
|
| 164 |
+
#if INTEL_MKL_VERSION >= 20180406
|
| 165 |
+
IMPLEMENT_VML_MKL(log2, Log2)
|
| 166 |
+
#endif
|
| 167 |
+
|
| 168 |
+
#endif
|
| 169 |
+
|
| 170 |
+
} // namespace CPU_CAPABILITY
|
| 171 |
+
} // namespace at::vml
|
| 172 |
+
|
| 173 |
+
#else
|
| 174 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 175 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/BLASConstants.h
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/TensorBase.h>
|
| 5 |
+
|
| 6 |
+
namespace at::cuda::detail {
|
| 7 |
+
|
| 8 |
+
float *get_cublas_device_one();
|
| 9 |
+
float *get_cublas_device_zero();
|
| 10 |
+
float *get_user_alpha_ptr();
|
| 11 |
+
|
| 12 |
+
} // namespace at::cuda::detail
|
| 13 |
+
|
| 14 |
+
#else
|
| 15 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 16 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 5 |
+
|
| 6 |
+
#include <ATen/Generator.h>
|
| 7 |
+
|
| 8 |
+
// TODO: No need to have this whole header, we can just put it all in
|
| 9 |
+
// the cpp file
|
| 10 |
+
|
| 11 |
+
namespace at::cuda::detail {
|
| 12 |
+
|
| 13 |
+
// Set the callback to initialize Magma, which is set by
|
| 14 |
+
// torch_cuda_cu. This indirection is required so magma_init is called
|
| 15 |
+
// in the same library where Magma will be used.
|
| 16 |
+
TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
// The real implementation of CUDAHooksInterface
|
| 20 |
+
struct CUDAHooks : public at::CUDAHooksInterface {
|
| 21 |
+
CUDAHooks(at::CUDAHooksArgs /*unused*/) {}
|
| 22 |
+
void init() const override;
|
| 23 |
+
Device getDeviceFromPtr(void* data) const override;
|
| 24 |
+
bool isPinnedPtr(const void* data) const override;
|
| 25 |
+
const Generator& getDefaultGenerator(
|
| 26 |
+
DeviceIndex device_index = -1) const override;
|
| 27 |
+
Generator getNewGenerator(
|
| 28 |
+
DeviceIndex device_index = -1) const override;
|
| 29 |
+
bool hasCUDA() const override;
|
| 30 |
+
bool hasMAGMA() const override;
|
| 31 |
+
bool hasCuDNN() const override;
|
| 32 |
+
bool hasCuSOLVER() const override;
|
| 33 |
+
bool hasCuBLASLt() const override;
|
| 34 |
+
bool hasROCM() const override;
|
| 35 |
+
bool hasCKSDPA() const override;
|
| 36 |
+
bool hasCKGEMM() const override;
|
| 37 |
+
const at::cuda::NVRTC& nvrtc() const override;
|
| 38 |
+
DeviceIndex current_device() const override;
|
| 39 |
+
bool isBuilt() const override {return true;}
|
| 40 |
+
bool isAvailable() const override {return hasCUDA();}
|
| 41 |
+
bool hasPrimaryContext(DeviceIndex device_index) const override;
|
| 42 |
+
Allocator* getCUDADeviceAllocator() const override;
|
| 43 |
+
Allocator* getPinnedMemoryAllocator() const override;
|
| 44 |
+
bool compiledWithCuDNN() const override;
|
| 45 |
+
bool compiledWithMIOpen() const override;
|
| 46 |
+
bool supportsDilatedConvolutionWithCuDNN() const override;
|
| 47 |
+
bool supportsDepthwiseConvolutionWithCuDNN() const override;
|
| 48 |
+
bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
|
| 49 |
+
bool supportsBFloat16RNNWithCuDNN() const override;
|
| 50 |
+
bool hasCUDART() const override;
|
| 51 |
+
long versionCUDART() const override;
|
| 52 |
+
long versionCuDNN() const override;
|
| 53 |
+
long versionRuntimeCuDNN() const override;
|
| 54 |
+
long versionCuDNNFrontend() const override;
|
| 55 |
+
long versionMIOpen() const override;
|
| 56 |
+
std::string showConfig() const override;
|
| 57 |
+
double batchnormMinEpsilonCuDNN() const override;
|
| 58 |
+
int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
|
| 59 |
+
void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
|
| 60 |
+
int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
|
| 61 |
+
void cuFFTClearPlanCache(DeviceIndex device_index) const override;
|
| 62 |
+
int getNumGPUs() const override;
|
| 63 |
+
DeviceIndex deviceCount() const override;
|
| 64 |
+
DeviceIndex getCurrentDevice() const override;
|
| 65 |
+
|
| 66 |
+
#ifdef USE_ROCM
|
| 67 |
+
bool isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index = -1) const override;
|
| 68 |
+
#endif
|
| 69 |
+
void deviceSynchronize(DeviceIndex device_index) const override;
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
} // namespace at::cuda::detail
|
| 73 |
+
|
| 74 |
+
#else
|
| 75 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 76 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
|
| 3 |
+
// These handles are tied to a device, and these libraries require/recommend not
|
| 4 |
+
// share handles across host threads.
|
| 5 |
+
//
|
| 6 |
+
// These libraries recommend using one handle per host thread. We may not want to do
|
| 7 |
+
// this because threads are relatively light-weight, but creating and destroying
|
| 8 |
+
// handles is expensive (destroying the handle causes synchronizations). DataParallel,
|
| 9 |
+
// for example, creates new threads for each forward pass.
|
| 10 |
+
//
|
| 11 |
+
// This file implements a handle pool mechanism. The handle pool returns handles on
|
| 12 |
+
// demand as threads request them. If all existing handles in the pool are in use,
|
| 13 |
+
// it creates a new one. As threads terminate, they release handles back into the pool.
|
| 14 |
+
// In this way, the handle pool never creates more handles than the high-water mark of
|
| 15 |
+
// active threads, so it's efficient with DataParallel.
|
| 16 |
+
|
| 17 |
+
#pragma once
|
| 18 |
+
|
| 19 |
+
#include <unordered_map>
|
| 20 |
+
#include <vector>
|
| 21 |
+
#include <utility>
|
| 22 |
+
#include <mutex>
|
| 23 |
+
#include <memory>
|
| 24 |
+
|
| 25 |
+
#include <c10/util/Exception.h>
|
| 26 |
+
|
| 27 |
+
namespace at::cuda { namespace {
|
| 28 |
+
|
| 29 |
+
// Pool of library handles (Handle_t) keyed by device index.
//
// Template parameters:
//   Handle_t - the raw handle type; it is initialized from nullptr and tested
//              for truthiness, so it must behave like a pointer.
//   Create   - called as Create(&handle) to make a new handle.
//   Destroy  - called as Destroy(handle) on every non-null owned handle.
//
// The pool must be owned by a std::shared_ptr: newPoolWindow() calls
// shared_from_this(), and each PoolWindow keeps only a weak reference back.
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    // Move-only owner of one raw handle. A default-constructed Handle holds
    // nullptr and destroys nothing.
    struct Handle {
    Handle_t handle;
    Handle(bool create = false) : handle(nullptr)
    {
        if(create) Create(&handle);
    }
    // std::vector.emplace() and push_back() may route through temporaries and call
    // copy/move constructors along the way.  If this is the case, we don't want
    // the destructors of temporaries to call cudnnDestroy on the handle.
    // We can achieve safety (for the narrow case of stashing within std::vectors)
    // by making Handle moveable but not copyable, and transferring handle ownership
    // to the latest constructed object.  This is not a substitute for full-blown
    // reference counting, but reference counting may be overkill here.
    // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
    // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
    Handle(const Handle& rhs) = delete;
    // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
    Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
    // operator= takes argument by value; the body only swaps two handle values,
    // so it is marked noexcept alongside the move constructor.
    Handle& operator=(Handle rhs) noexcept { std::swap(handle, rhs.handle); return *this; }
    ~Handle() {
        if(handle) Destroy(handle);
    }
    };

    // Guards created_handles and available_handles.
    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. Those 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then.  This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (ie, before any of them have exited).  We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
    PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
    ~PoolWindow(){ release(); }

    // Returns the handle this window holds for `device`, acquiring one from
    // the parent pool (reusing an available handle, else creating a new one)
    // on first use. Fails via TORCH_CHECK if the parent pool no longer exists.
    Handle_t reserve(int device)
    {
        // If this thread already has a handle for this device, return it.
        // A single find() serves both the membership test and the read.
        const auto cached = my_handles.find(device);
        if(cached != my_handles.end())
            return cached->second;

        // otherwise, either grab a handle from the pool if one is available,
        // or if not, create a new one.
        auto parent = weak_parent.lock();
        TORCH_CHECK(parent, "Cannot create handle during program termination");
        std::lock_guard<std::mutex> guard(parent->mutex);

        Handle_t handle = nullptr;
        // Look each per-device vector up once to keep the critical section short.
        auto& available = parent->available_handles[device];
        if(!available.empty())
        {
            handle = available.back();
            available.pop_back();
        }
        else
        {
            // In local testing, I do observe that emplace_back sometimes routes through temporaries
            // that incur move-constructor and destructor calls.  See comments in Handle above.
            auto& created = parent->created_handles[device];
            created.emplace_back(true /*create*/);
            handle = created.back().handle;
        }

        // Record the handle only after it has been obtained successfully.
        my_handles[device] = handle;
        return handle;
    }

    private:
    // Stores the per-device handles currently owned by this thread
    std::unordered_map<int, Handle_t> my_handles;

    std::weak_ptr<DeviceThreadHandlePool> weak_parent;

    // Called by the destructor.  Releases this thread's handles back into the pool.
    void release() {
        if(my_handles.empty())
            return;

        auto parent = weak_parent.lock();
        if (!parent) {
            // If this thread exits after atexit handlers have completed, the
            // cuda context itself may be invalid, so we must leak the handles.
            return;
        }

        std::lock_guard<std::mutex> guard(parent->mutex);
        for(const auto& d_h : my_handles)
            parent->available_handles[d_h.first].push_back(d_h.second);
    }
    };

    // Warning:
    // If you want to change this function, be aware that this function will be called
    // by multiple threads and there is no mutex guarding the call of this function, so
    // make sure your implementation is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread local variable
        // so that different threads do not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};
|
| 151 |
+
|
| 152 |
+
}} // namespace at::cuda::<anonymous>
|
| 153 |
+
|
| 154 |
+
#else
|
| 155 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 156 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/TensorBase.h>
|
| 5 |
+
#include <ATen/cuda/detail/TensorInfo.cuh>
|
| 6 |
+
#include <ATen/native/CanUse32BitIndexMath.h>
|
| 7 |
+
|
| 8 |
+
namespace at::cuda::detail {
|
| 9 |
+
|
| 10 |
+
TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
|
| 11 |
+
using at::native::canUse32BitIndexMath;
|
| 12 |
+
|
| 13 |
+
template <typename scalar, typename IndexType>
|
| 14 |
+
TensorInfo<scalar, IndexType>
|
| 15 |
+
getTensorInfo(const at::TensorBase &t) {
|
| 16 |
+
IndexType sz[MAX_TENSORINFO_DIMS];
|
| 17 |
+
IndexType st[MAX_TENSORINFO_DIMS];
|
| 18 |
+
|
| 19 |
+
int dims = t.dim();
|
| 20 |
+
for (int i = 0; i < dims; ++i) {
|
| 21 |
+
sz[i] = t.size(i);
|
| 22 |
+
st[i] = t.stride(i);
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
scalar* data_ptr = nullptr;
|
| 26 |
+
|
| 27 |
+
if constexpr (std::is_const_v<scalar>) {
|
| 28 |
+
data_ptr = t.const_data_ptr<scalar>();
|
| 29 |
+
} else {
|
| 30 |
+
data_ptr = t.mutable_data_ptr<scalar>();
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
return TensorInfo<scalar, IndexType>(
|
| 34 |
+
data_ptr, dims, sz, st);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
} // namespace at::cuda::detail
|
| 38 |
+
|
| 39 |
+
#else
|
| 40 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 41 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <assert.h>
|
| 5 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
| 6 |
+
#include <cuda_runtime.h>
|
| 7 |
+
#endif
|
| 8 |
+
|
| 9 |
+
namespace at::cuda::detail {
|
| 10 |
+
|
| 11 |
+
// A utility class to implement integer division by multiplication, given a fixed
|
| 12 |
+
// divisor.
|
| 13 |
+
//
|
| 14 |
+
// WARNING: The fast divider algorithm is only implemented for unsigned int;
|
| 15 |
+
// otherwise we default to plain integer division. For unsigned int,
|
| 16 |
+
// we further assume that the dividend is at most INT32_MAX. Thus,
|
| 17 |
+
// IntDivider must NOT be used for general integer division.
|
| 18 |
+
//
|
| 19 |
+
// This reduced range is enough for our purpose, and it allows us to
|
| 20 |
+
// slightly simplify the computation.
|
| 21 |
+
//
|
| 22 |
+
// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1<<k.)
|
| 23 |
+
//
|
| 24 |
+
// For any N-bit unsigned integer d (> 0), we can find a "magic number" m (2^N
|
| 25 |
+
// <= m < 2^(N+1)) and shift s such that:
|
| 26 |
+
//
|
| 27 |
+
// \floor(n / d) = \floor((m * n) / 2^(N+s)).
|
| 28 |
+
//
|
| 29 |
+
// Given such m and s, the integer division can be then implemented as:
|
| 30 |
+
//
|
| 31 |
+
// let m' = m - 2^N // 0 <= m' < 2^N
|
| 32 |
+
//
|
| 33 |
+
// fast_integer_division(n):
|
| 34 |
+
// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
|
| 35 |
+
// // integer. Then take the higher N bits.
|
| 36 |
+
// t = (m' * n) >> N
|
| 37 |
+
//
|
| 38 |
+
// // Here we use the fact that n is less than 2^(N-1): otherwise the value
|
| 39 |
+
// // of (t + n) may not fit in an N-bit integer.
|
| 40 |
+
// return (t + n) >> s
|
| 41 |
+
//
|
| 42 |
+
// Finding such a magic number is surprisingly easy:
|
| 43 |
+
//
|
| 44 |
+
// s = \ceil(\log_2 d)
|
| 45 |
+
// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic.
|
| 46 |
+
//
|
| 47 |
+
// See also:
|
| 48 |
+
// - Division by Invariant Integers Using Multiplication,
|
| 49 |
+
// Torbjörn Granlund and Peter L. Montgomery, 1994.
|
| 50 |
+
//
|
| 51 |
+
// - http://www.hackersdelight.org/magic.htm
|
| 52 |
+
//
|
| 53 |
+
// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
|
| 54 |
+
|
| 55 |
+
// Result of div/mod operation stored together.
|
| 56 |
+
template <typename Value>
|
| 57 |
+
struct DivMod {
|
| 58 |
+
Value div, mod;
|
| 59 |
+
|
| 60 |
+
C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
// Base case: we only have an implementation for uint32_t for now. For
|
| 64 |
+
// everything else, we use plain division.
|
| 65 |
+
template <typename Value>
|
| 66 |
+
struct IntDivider {
|
| 67 |
+
IntDivider() = default;
|
| 68 |
+
IntDivider(Value d) : divisor(d) { }
|
| 69 |
+
|
| 70 |
+
C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
|
| 71 |
+
C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
|
| 72 |
+
C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
|
| 73 |
+
return DivMod<Value>(n / divisor, n % divisor);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
Value divisor;
|
| 77 |
+
};
|
| 78 |
+
|
| 79 |
+
// Implement fast integer division.
|
| 80 |
+
template <>
|
| 81 |
+
struct IntDivider<unsigned int> {
|
| 82 |
+
static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");
|
| 83 |
+
|
| 84 |
+
IntDivider() = default;
|
| 85 |
+
|
| 86 |
+
IntDivider(unsigned int d) : divisor(d) {
|
| 87 |
+
assert(divisor >= 1 && divisor <= INT32_MAX);
|
| 88 |
+
|
| 89 |
+
// TODO: gcc/clang has __builtin_clz() but it's not portable.
|
| 90 |
+
for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;
|
| 91 |
+
|
| 92 |
+
uint64_t one = 1;
|
| 93 |
+
uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
|
| 94 |
+
m1 = magic;
|
| 95 |
+
assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits.
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
|
| 99 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
| 100 |
+
// 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
|
| 101 |
+
// 'm1'.
|
| 102 |
+
unsigned int t = __umulhi(n, m1);
|
| 103 |
+
return (t + n) >> shift;
|
| 104 |
+
#else
|
| 105 |
+
// Using uint64_t so that the addition does not overflow.
|
| 106 |
+
uint64_t t = ((uint64_t) n * m1) >> 32;
|
| 107 |
+
return (t + n) >> shift;
|
| 108 |
+
#endif
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
|
| 112 |
+
return n - div(n) * divisor;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
|
| 116 |
+
unsigned int q = div(n);
|
| 117 |
+
return DivMod<unsigned int>(q, n - q * divisor);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
unsigned int divisor; // d above.
|
| 121 |
+
unsigned int m1; // Magic number: m' above.
|
| 122 |
+
unsigned int shift; // Shift amounts.
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
} // namespace at::cuda::detail
|
| 126 |
+
|
| 127 |
+
#else
|
| 128 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 129 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <limits>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
|
| 7 |
+
namespace at::cuda::detail {
|
| 8 |
+
|
| 9 |
+
// CUDA: grid stride looping
|
| 10 |
+
//
|
| 11 |
+
// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
|
| 12 |
+
// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
|
| 13 |
+
// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
|
| 14 |
+
// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
|
| 15 |
+
// further iterations and the overflowed value in i=_i_n_d_e_x is not used.
|
| 16 |
+
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
|
| 17 |
+
int64_t _i_n_d_e_x = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; \
|
| 18 |
+
for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
|
| 19 |
+
|
| 20 |
+
#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
// Use 1024 threads per block, which requires cuda sm_2x or above
|
| 24 |
+
constexpr int CUDA_NUM_THREADS = 1024;
|
| 25 |
+
|
| 26 |
+
// CUDA: number of blocks for threads.
|
| 27 |
+
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
|
| 28 |
+
TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
|
| 29 |
+
constexpr int64_t max_int = std::numeric_limits<int>::max();
|
| 30 |
+
|
| 31 |
+
// Round up division for positive number that cannot cause integer overflow
|
| 32 |
+
auto block_num = (N - 1) / max_threads_per_block + 1;
|
| 33 |
+
TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
|
| 34 |
+
|
| 35 |
+
return static_cast<int>(block_num);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
} // namespace at::cuda::detail
|
| 39 |
+
|
| 40 |
+
#else
|
| 41 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 42 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 4 |
+
namespace at::cuda {
|
| 5 |
+
// Forward-declares at::cuda::NVRTC
|
| 6 |
+
struct NVRTC;
|
| 7 |
+
|
| 8 |
+
namespace detail {
|
| 9 |
+
extern NVRTC lazyNVRTC;
|
| 10 |
+
} // namespace detail
|
| 11 |
+
|
| 12 |
+
} // namespace at::cuda
|
| 13 |
+
|
| 14 |
+
#else
|
| 15 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 16 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <array>
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
#include <type_traits>
|
| 7 |
+
#include <c10/macros/Macros.h>
|
| 8 |
+
#include <ATen/native/TensorIterator.h>
|
| 9 |
+
#include <ATen/cuda/detail/IntegerDivider.cuh>
|
| 10 |
+
|
| 11 |
+
// If element_sizes is nullptr, then the strides will be in bytes, otherwise
|
| 12 |
+
// the strides will be in # of elements.
|
| 13 |
+
// Operands that share the same shape, but may have different strides.
|
| 14 |
+
// OffsetCalculator iterates the tensor in a column-major order
|
| 15 |
+
|
| 16 |
+
#if defined(USE_ROCM)
|
| 17 |
+
constexpr int MAX_DIMS = 16;
|
| 18 |
+
#else
|
| 19 |
+
constexpr int MAX_DIMS = 25;
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
|
| 23 |
+
struct OffsetCalculator {
|
| 24 |
+
// We allow having negative strides to implement some operations like torch.flip
|
| 25 |
+
using stride_t = std::conditional_t<signed_strides,
|
| 26 |
+
std::make_signed_t<index_t>,
|
| 27 |
+
index_t>;
|
| 28 |
+
// The offset for each argument. Wrapper around fixed-size array.
|
| 29 |
+
// On CUDA, zero sized array is not allowed, so when we are handling nullary
|
| 30 |
+
// operators, we need to create a size 1 offset to avoid compiler failure.
|
| 31 |
+
// This size 1 offset is just a placeholder, and we will not use it.
|
| 32 |
+
using offset_type = std::array<stride_t, std::max<int>(NARGS, 1)>;
|
| 33 |
+
|
| 34 |
+
// if element_sizes is nullptr, then the strides will be in bytes, otherwise
|
| 35 |
+
// the strides will be in # of elements.
|
| 36 |
+
OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
|
| 37 |
+
TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
|
| 38 |
+
for (int i=0; i < dims; i++){
|
| 39 |
+
sizes_[i] = at::cuda::detail::IntDivider<index_t>(sizes[i]);
|
| 40 |
+
for (int arg = 0; arg < NARGS; arg++) {
|
| 41 |
+
int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
|
| 42 |
+
strides_[i][arg] = strides[arg][i] / element_size;
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
|
| 48 |
+
offset_type offsets;
|
| 49 |
+
|
| 50 |
+
#if defined(USE_ROCM)
|
| 51 |
+
if ((dims > 0) && (dims <= 2)) {
|
| 52 |
+
auto divmod = sizes_[0].divmod(linear_idx);
|
| 53 |
+
#pragma unroll
|
| 54 |
+
for (int arg = 0; arg < NARGS; arg++)
|
| 55 |
+
offsets[arg] = divmod.mod * strides_[0][arg];
|
| 56 |
+
if (dims >= 2) {
|
| 57 |
+
divmod = sizes_[1].divmod(divmod.div);
|
| 58 |
+
#pragma unroll
|
| 59 |
+
for (int arg = 0; arg < NARGS; arg++)
|
| 60 |
+
offsets[arg] += divmod.mod * strides_[1][arg];
|
| 61 |
+
}
|
| 62 |
+
// [...]
|
| 63 |
+
return offsets;
|
| 64 |
+
}
|
| 65 |
+
#endif
|
| 66 |
+
|
| 67 |
+
#pragma unroll
|
| 68 |
+
for (int arg = 0; arg < NARGS; arg++) {
|
| 69 |
+
offsets[arg] = 0;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
#pragma unroll
|
| 73 |
+
for (int dim = 0; dim < MAX_DIMS; ++dim) {
|
| 74 |
+
if (dim == dims) {
|
| 75 |
+
break;
|
| 76 |
+
}
|
| 77 |
+
auto divmod = sizes_[dim].divmod(linear_idx);
|
| 78 |
+
linear_idx = divmod.div;
|
| 79 |
+
|
| 80 |
+
#pragma unroll
|
| 81 |
+
for (int arg = 0; arg < NARGS; arg++) {
|
| 82 |
+
offsets[arg] += divmod.mod * strides_[dim][arg];
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
}
|
| 86 |
+
return offsets;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
int dims;
|
| 90 |
+
at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
|
| 91 |
+
stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
template <int NARGS, typename index_t = uint32_t>
|
| 95 |
+
struct TrivialOffsetCalculator {
|
| 96 |
+
// The offset for each argument. Wrapper around fixed-size array.
|
| 97 |
+
// The offsets are in # of elements, not in bytes.
|
| 98 |
+
// On CUDA, zero sized array is not allowed, so when we are handling nullary
|
| 99 |
+
// operators, we need to create a size 1 offset to avoid compiler failure.
|
| 100 |
+
// This size 1 offset is just a placeholder, and we will not use it.
|
| 101 |
+
using offset_type = std::array<index_t, std::max<int>(NARGS, 1)>;
|
| 102 |
+
|
| 103 |
+
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
|
| 104 |
+
offset_type offsets;
|
| 105 |
+
#pragma unroll
|
| 106 |
+
for (int arg = 0; arg < NARGS; arg++) {
|
| 107 |
+
offsets[arg] = linear_idx;
|
| 108 |
+
}
|
| 109 |
+
return offsets;
|
| 110 |
+
}
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
// Make an OffsetCalculator with byte offsets
|
| 114 |
+
template<int N, bool signed_strides = false>
|
| 115 |
+
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
|
| 116 |
+
TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
|
| 117 |
+
std::array<const int64_t*, N> strides;
|
| 118 |
+
for (int i = 0; i < N; i++) {
|
| 119 |
+
strides[i] = iter.strides(i).data();
|
| 120 |
+
}
|
| 121 |
+
return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// Make an OffsetCalculator with element offsets
|
| 125 |
+
template<int N, bool signed_strides = false>
|
| 126 |
+
static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
|
| 127 |
+
const at::TensorIteratorBase& iter) {
|
| 128 |
+
TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
|
| 129 |
+
std::array<const int64_t*, N> strides;
|
| 130 |
+
std::array<int64_t, N> element_sizes;
|
| 131 |
+
for (int i = 0; i < N; i++) {
|
| 132 |
+
strides[i] = iter.strides(i).data();
|
| 133 |
+
element_sizes[i] = iter.element_size(i);
|
| 134 |
+
}
|
| 135 |
+
return OffsetCalculator<N, uint32_t, signed_strides>(
|
| 136 |
+
iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
#else
|
| 140 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 141 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
|
| 3 |
+
// Eager mode clients should not include this file directly, instead,
|
| 4 |
+
// they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
|
| 5 |
+
|
| 6 |
+
// Stores RNG state values. Passed as a kernel argument.
|
| 7 |
+
// See Note [CUDA Graph-safe RNG states].
|
| 8 |
+
//
|
| 9 |
+
// The raw definition lives in its own file so jit codegen can easily copy it.
|
| 10 |
+
namespace at {
|
| 11 |
+
|
| 12 |
+
struct PhiloxCudaState {
  // Either a literal value (no graph capture) or a pointer to the value
  // (graph capture underway).
  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  PhiloxCudaState() = default;

  // Called if graph capture is not underway: stores seed and offset by value.
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }

  // Called if graph capture is underway: stores pointers to the seed and the
  // extragraph offset, plus the intragraph offset, and marks the state as
  // captured.
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint64_t offset_intragraph)
      : offset_intragraph_(offset_intragraph), captured_(true) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  Payload seed_{};
  Payload offset_{};
  uint64_t offset_intragraph_ = 0;
  bool captured_ = false;
};
|
| 43 |
+
|
| 44 |
+
} // namespace at
|
| 45 |
+
|
| 46 |
+
#else
|
| 47 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 48 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/CollapseDims.h>
|
| 5 |
+
|
| 6 |
+
namespace at::cuda::detail {
|
| 7 |
+
|
| 8 |
+
#define MAX_TENSORINFO_DIMS 25
|
| 9 |
+
|
| 10 |
+
// CUDA kernel argument that defines tensor layout
|
| 11 |
+
template <typename T, typename IndexType>
|
| 12 |
+
struct TensorInfo {
|
| 13 |
+
TensorInfo();
|
| 14 |
+
TensorInfo(T* p,
|
| 15 |
+
int dim,
|
| 16 |
+
IndexType sz[MAX_TENSORINFO_DIMS],
|
| 17 |
+
IndexType st[MAX_TENSORINFO_DIMS]);
|
| 18 |
+
|
| 19 |
+
// Set the size of the given dimension to 1, as if it were a
|
| 20 |
+
// reduction dim (allows you to calculate offsets of the reduction
|
| 21 |
+
// slice)
|
| 22 |
+
void reduceDim(int dim);
|
| 23 |
+
|
| 24 |
+
// See note on [collapse dims].
|
| 25 |
+
int collapseDims(const int excludeDim = -1);
|
| 26 |
+
|
| 27 |
+
// Contiguous tensors of more than one dimension are collapsed down
|
| 28 |
+
// to one tensor
|
| 29 |
+
__host__ __device__ inline bool isContiguous() const {
|
| 30 |
+
return (dims == 1 && strides[0] == 1);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
T* data;
|
| 34 |
+
IndexType sizes[MAX_TENSORINFO_DIMS];
|
| 35 |
+
IndexType strides[MAX_TENSORINFO_DIMS];
|
| 36 |
+
int dims;
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
template <typename T, typename IndexType>
|
| 40 |
+
TensorInfo<T, IndexType>::TensorInfo() {
|
| 41 |
+
data = nullptr;
|
| 42 |
+
dims = 0;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template <typename T, typename IndexType>
|
| 46 |
+
TensorInfo<T, IndexType>::TensorInfo(T* p,
|
| 47 |
+
int dim,
|
| 48 |
+
IndexType sz[MAX_TENSORINFO_DIMS],
|
| 49 |
+
IndexType st[MAX_TENSORINFO_DIMS]) {
|
| 50 |
+
data = p;
|
| 51 |
+
dims = dim;
|
| 52 |
+
TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");
|
| 53 |
+
|
| 54 |
+
for (int i = 0; i < dim; ++i) {
|
| 55 |
+
sizes[i] = sz[i];
|
| 56 |
+
strides[i] = st[i];
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
template <typename T, typename IndexType>
|
| 61 |
+
void
|
| 62 |
+
TensorInfo<T, IndexType>::reduceDim(int dim) {
|
| 63 |
+
TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
|
| 64 |
+
sizes[dim] = 1;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template <typename T, typename IndexType>
|
| 68 |
+
int
|
| 69 |
+
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
|
| 70 |
+
auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
|
| 71 |
+
dims = std::get<1>(result);
|
| 72 |
+
return std::get<0>(result);
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// Translate a linear index for the apply to a T* offset;
|
| 76 |
+
// specialized on `Dims` to reduce nvcc compilation time
|
| 77 |
+
template <typename T, typename IndexType, int Dims>
|
| 78 |
+
struct IndexToOffset {
|
| 79 |
+
static __host__ __device__ IndexType get(
|
| 80 |
+
IndexType linearId,
|
| 81 |
+
const TensorInfo<T, IndexType>& info) {
|
| 82 |
+
|
| 83 |
+
IndexType offset = 0;
|
| 84 |
+
|
| 85 |
+
// Uses static dims
|
| 86 |
+
for (int i = Dims - 1; i > 0; --i) {
|
| 87 |
+
IndexType curDimIndex = linearId % info.sizes[i];
|
| 88 |
+
IndexType curDimOffset = curDimIndex * info.strides[i];
|
| 89 |
+
offset += curDimOffset;
|
| 90 |
+
linearId /= info.sizes[i];
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
return offset + linearId * info.strides[0];
|
| 94 |
+
}
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
// Uses dynamic (runtime) instead of static (compile time) dims
|
| 98 |
+
template <typename T, typename IndexType>
|
| 99 |
+
struct IndexToOffset<T, IndexType, -1> {
|
| 100 |
+
static inline __host__ __device__ IndexType get(
|
| 101 |
+
IndexType linearId,
|
| 102 |
+
const TensorInfo<T, IndexType>& info) {
|
| 103 |
+
|
| 104 |
+
IndexType offset = 0;
|
| 105 |
+
|
| 106 |
+
for (int i = info.dims - 1; i > 0; --i) {
|
| 107 |
+
IndexType curDimIndex = linearId % info.sizes[i];
|
| 108 |
+
IndexType curDimOffset = curDimIndex * info.strides[i];
|
| 109 |
+
offset += curDimOffset;
|
| 110 |
+
linearId /= info.sizes[i];
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
return offset + linearId * info.strides[0];
|
| 114 |
+
}
|
| 115 |
+
};
|
| 116 |
+
|
| 117 |
+
} // namespace at::cuda::detail
|
| 118 |
+
|
| 119 |
+
#else
|
| 120 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 121 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
|
| 3 |
+
// Eager mode clients should not include this file directly, instead,
|
| 4 |
+
// they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
|
| 5 |
+
|
| 6 |
+
namespace at::cuda::philox {
|
| 7 |
+
|
| 8 |
+
// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
|
| 9 |
+
// that instance was created with graph capture underway or not.
|
| 10 |
+
// See Note [CUDA Graph-safe RNG states].
|
| 11 |
+
//
|
| 12 |
+
// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
|
| 13 |
+
// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
|
| 14 |
+
// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
|
| 15 |
+
//
|
| 16 |
+
// The raw definition lives in its own file so jit codegen can easily copy it.
|
| 17 |
+
__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
unpack(at::PhiloxCudaState arg) {
  if (!arg.captured_) {
    // Not captured: seed and offset are stored by value.
    return std::make_tuple(arg.seed_.val, arg.offset_.val);
  }
  // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
  // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
  // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
  const uint64_t seed = static_cast<uint64_t>(*arg.seed_.ptr);
  const uint64_t offset = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
  return std::make_tuple(seed, offset);
}
|
| 28 |
+
|
| 29 |
+
// Adapted from TE
|
| 30 |
+
// extract seed and offset from PhiloxCudaState
|
| 31 |
+
__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr);
|
| 32 |
+
|
| 33 |
+
void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream);
|
| 34 |
+
|
| 35 |
+
} // namespace at::cuda::philox
|
| 36 |
+
|
| 37 |
+
#else
|
| 38 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 39 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h
ADDED
|
@@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Original TunableOp is from onnxruntime.
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 4 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 5 |
+
// Copyright (c) Microsoft Corporation.
|
| 6 |
+
// Licensed under the MIT license.
|
| 7 |
+
//
|
| 8 |
+
// Adapting TunableOp into PyTorch
|
| 9 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 10 |
+
//
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include <string>
|
| 14 |
+
#include <c10/core/ScalarType.h>
|
| 15 |
+
|
| 16 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 17 |
+
#include <ATen/cuda/tunable/Tunable.h>
|
| 18 |
+
#include <ATen/cuda/CUDABlas.h>
|
| 19 |
+
#include <ATen/cuda/Exceptions.h>
|
| 20 |
+
#include <c10/util/StringUtil.h>
|
| 21 |
+
|
| 22 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 23 |
+
#include <ATen/Functions.h>
|
| 24 |
+
#include <ATen/NativeFunctions.h>
|
| 25 |
+
#else
|
| 26 |
+
#include <ATen/ops/allclose.h>
|
| 27 |
+
#include <ATen/ops/from_blob.h>
|
| 28 |
+
#endif
|
| 29 |
+
#include <ATen/OpMathType.h>
|
| 30 |
+
#include <fmt/printf.h>
|
| 31 |
+
|
| 32 |
+
namespace at::cuda::tunable {
|
| 33 |
+
|
| 34 |
+
using at::blas::ScalingType;
|
| 35 |
+
|
| 36 |
+
// Operand op selector; BlasOpToString maps the enumerators to 'N' and 'T'.
// NOTE(review): presumably "no transpose" / "transpose" per BLAS convention.
enum class BlasOp {
  N,  // value 0
  T,  // value 1
};
|
| 40 |
+
|
| 41 |
+
// Returns the single-character tag for `op`: 'N' for BlasOp::N, 'T' for
// BlasOp::T. Any other value fails the TORCH_CHECK.
inline char BlasOpToString(BlasOp op) {
  if (op == BlasOp::N) {
    return 'N';
  }
  if (op == BlasOp::T) {
    return 'T';
  }
  TORCH_CHECK(false, "unrecognized BlasOp");
  return 'N';
}
|
| 51 |
+
|
| 52 |
+
template <typename T>
|
| 53 |
+
inline const char* BLASTypeName(T v) {
|
| 54 |
+
return "unknown";
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
template <>
|
| 58 |
+
inline const char* BLASTypeName(float v) {
|
| 59 |
+
return "f32_r";
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
template <>
|
| 63 |
+
inline const char* BLASTypeName(double v) {
|
| 64 |
+
return "f64_r";
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template <>
|
| 68 |
+
inline const char* BLASTypeName(BFloat16 v) {
|
| 69 |
+
return "bf16_r";
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <>
|
| 73 |
+
inline const char* BLASTypeName(Half v) {
|
| 74 |
+
return "f16_r";
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
//https://github.com/ROCm/hipBLASLt/blob/develop/library/src/include/auxiliary.hpp#L175
|
| 78 |
+
template <>
|
| 79 |
+
inline const char* BLASTypeName(Float8_e4m3fn v) {
|
| 80 |
+
return "f8_r";
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template <>
|
| 84 |
+
inline const char* BLASTypeName(Float8_e5m2 v) {
|
| 85 |
+
return "bf8_r";
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
template <>
|
| 89 |
+
inline const char* BLASTypeName(Float8_e4m3fnuz v) {
|
| 90 |
+
return "f8_fnuz_r";
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
template <>
|
| 94 |
+
inline const char* BLASTypeName(Float8_e5m2fnuz v) {
|
| 95 |
+
return "bf8_fnuz_r";
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
template <>
|
| 99 |
+
inline const char* BLASTypeName(c10::complex<double> v) {
|
| 100 |
+
return "f64_r";
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
template <>
|
| 104 |
+
inline const char* BLASTypeName(c10::complex<float> v) {
|
| 105 |
+
return "f32_r";
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
/// Runtime counterpart of BLASTypeName: maps a c10::ScalarType to the BLAS
/// data-type name used in signature strings.
///
/// @param scalar_type  Scalar type to translate.
/// @return The matching name, or "unknown" for any scalar type not listed.
inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
  const char* name = "unknown";
  switch (scalar_type) {
    case c10::ScalarType::Float:
      name = "f32_r";
      break;
    case c10::ScalarType::Double:
      name = "f64_r";
      break;
    case c10::ScalarType::BFloat16:
      name = "bf16_r";
      break;
    case c10::ScalarType::Half:
      name = "f16_r";
      break;
    case c10::ScalarType::Float8_e4m3fn:
      name = "f8_r";
      break;
    case c10::ScalarType::Float8_e5m2:
      name = "bf8_r";
      break;
    case c10::ScalarType::Float8_e4m3fnuz:
      name = "f8_fnuz_r";
      break;
    case c10::ScalarType::Float8_e5m2fnuz:
      name = "bf8_fnuz_r";
      break;
    case c10::ScalarType::ComplexFloat:
      name = "f32_c";
      break;
    case c10::ScalarType::ComplexDouble:
      name = "f64_c";
      break;
    default:
      // Keep the "unknown" placeholder.
      break;
  }
  return std::string(name);
}
|
| 157 |
+
|
| 158 |
+
// Similar to Compute Type in GemmRocblas.h
/// Name of the compute type used for GEMMs on element type T.
/// This primary template is the fallback for element types that have no
/// specialization; it returns the placeholder "Unknown ComputeType".
template <typename T>
inline std::string ComputeTypeFor() {
  return std::string("Unknown ComputeType");
}
|
| 163 |
+
|
| 164 |
+
// This is a union of the compute types for
|
| 165 |
+
// ROCBLAS and hipBLASLt.
|
| 166 |
+
template <>
|
| 167 |
+
inline std::string ComputeTypeFor<float>() {
|
| 168 |
+
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) != at::Float32Precision::TF32) {
|
| 169 |
+
return "f32_r";
|
| 170 |
+
} else {
|
| 171 |
+
return "xf32_r";
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
template <>
|
| 176 |
+
inline std::string ComputeTypeFor<double>() {
|
| 177 |
+
return "f64_r";
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
template <>
|
| 181 |
+
inline std::string ComputeTypeFor<Half>() {
|
| 182 |
+
return "f32_r";
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
template <>
|
| 186 |
+
inline std::string ComputeTypeFor<BFloat16>() {
|
| 187 |
+
return "f32_r";
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
template <>
|
| 191 |
+
inline std::string ComputeTypeFor<c10::complex<float>>() {
|
| 192 |
+
return "f32_c";
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
template <>
|
| 196 |
+
inline std::string ComputeTypeFor<c10::complex<double>>() {
|
| 197 |
+
return "f64_c";
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
template <>
|
| 201 |
+
inline std::string ComputeTypeFor<Float8_e4m3fn>() {
|
| 202 |
+
return "f32_r";
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
template <>
|
| 206 |
+
inline std::string ComputeTypeFor<Float8_e5m2>() {
|
| 207 |
+
return "f32_r";
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
template <>
|
| 211 |
+
inline std::string ComputeTypeFor<Float8_e4m3fnuz>() {
|
| 212 |
+
return "f32_r";
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
template <>
|
| 216 |
+
inline std::string ComputeTypeFor<Float8_e5m2fnuz>() {
|
| 217 |
+
return "f32_r";
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
// Convert opmath_type<T> to string
|
| 221 |
+
template <typename T>
|
| 222 |
+
inline std::string to_string_opmath(const at::opmath_type<T>& value) {
|
| 223 |
+
if constexpr (std::is_same_v<at::opmath_type<T>, c10::complex<float>> ||
|
| 224 |
+
std::is_same_v<at::opmath_type<T>, c10::complex<double>>) {
|
| 225 |
+
return fmt::format("({:.4f}, {:.4f})", value.real(), value.imag());
|
| 226 |
+
} else {
|
| 227 |
+
return fmt::format("{:.4f}", value);
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
// convert activation epilogue to string
|
| 232 |
+
inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivationEpilogue& value) {
|
| 233 |
+
switch (value) {
|
| 234 |
+
case at::cuda::blas::GEMMAndBiasActivationEpilogue::None:
|
| 235 |
+
return std::string("None");
|
| 236 |
+
break;
|
| 237 |
+
case at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU:
|
| 238 |
+
return std::string("RELU");
|
| 239 |
+
break;
|
| 240 |
+
case cuda::blas::GEMMAndBiasActivationEpilogue::GELU:
|
| 241 |
+
return std::string("GELU");
|
| 242 |
+
break;
|
| 243 |
+
default:
|
| 244 |
+
return std::string("unknown");
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
namespace detail {
|
| 249 |
+
|
| 250 |
+
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
|
| 251 |
+
|
| 252 |
+
if (!config.enabled) {
|
| 253 |
+
return true; // skip when disabled
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
|
| 257 |
+
at::Tensor ref = at::from_blob(c, {size}, options);
|
| 258 |
+
at::Tensor oth = at::from_blob(other_c, {size}, options);
|
| 259 |
+
at::Tensor ref_float = ref.to(at::kFloat);
|
| 260 |
+
at::Tensor oth_float = oth.to(at::kFloat);
|
| 261 |
+
|
| 262 |
+
const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
|
| 263 |
+
if (ok) {
|
| 264 |
+
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
|
| 265 |
+
} else {
|
| 266 |
+
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
|
| 267 |
+
}
|
| 268 |
+
return ok;
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
// Note on GetSizeA et al.
|
| 274 |
+
// Tensors can be dense or arbitrarily strided. We only need our copies to be large enough.
|
| 275 |
+
// Our copies must be at least as large as the m n k shapes dictate, but could be larger
|
| 276 |
+
// depending on the lda ldb ldc values. Similarly for the batched case.
|
| 277 |
+
|
| 278 |
+
template <typename T>
|
| 279 |
+
struct GemmParams : OpParams {
|
| 280 |
+
GemmParams() = default;
|
| 281 |
+
|
| 282 |
+
std::string BLASSignature() const override {
|
| 283 |
+
std::string alpha_str = to_string_opmath<T>(alpha);
|
| 284 |
+
std::string beta_str = to_string_opmath<T>(beta);
|
| 285 |
+
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
|
| 286 |
+
"alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, bias_type: %s, compute_type: %s }",
|
| 287 |
+
m, n, k, lda, ldb, ldc, ldc, alpha_str, beta_str, transa, transb,
|
| 288 |
+
BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>(), ComputeTypeFor<T>());
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
std::string Signature() const override {
|
| 292 |
+
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc);
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
size_t GetSizeA() const {
|
| 296 |
+
size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m);
|
| 297 |
+
size_t size_dense = m * k;
|
| 298 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
size_t GetSizeB() const {
|
| 302 |
+
size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k);
|
| 303 |
+
size_t size_dense = k * n;
|
| 304 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
size_t GetSizeC() const {
|
| 308 |
+
size_t size_stride = ldc * n;
|
| 309 |
+
size_t size_dense = m * n;
|
| 310 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
size_t GetSize(bool duplicate_inputs) const {
|
| 314 |
+
size_t size = GetSizeC();
|
| 315 |
+
if (duplicate_inputs) {
|
| 316 |
+
size += GetSizeA();
|
| 317 |
+
size += GetSizeB();
|
| 318 |
+
}
|
| 319 |
+
return size;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
GemmParams* DeepCopy(bool duplicate_inputs) const {
|
| 323 |
+
GemmParams* copy = new GemmParams;
|
| 324 |
+
*copy = *this;
|
| 325 |
+
c10::DeviceIndex device = 0;
|
| 326 |
+
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
| 327 |
+
size_t c_size = GetSizeC();
|
| 328 |
+
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
|
| 329 |
+
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
|
| 330 |
+
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
|
| 331 |
+
if (duplicate_inputs) {
|
| 332 |
+
size_t a_size = GetSizeA();
|
| 333 |
+
size_t b_size = GetSizeB();
|
| 334 |
+
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
|
| 335 |
+
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
|
| 336 |
+
copy->duplicate_inputs_ = true;
|
| 337 |
+
}
|
| 338 |
+
return copy;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
// only call on object returned by DeepCopy
|
| 342 |
+
void Delete() {
|
| 343 |
+
c10::cuda::CUDACachingAllocator::raw_delete(c);
|
| 344 |
+
if (duplicate_inputs_) {
|
| 345 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 346 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
|
| 347 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 348 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
TuningStatus NumericalCheck(GemmParams<T> *other) {
|
| 353 |
+
auto* ctx = getTuningContext();
|
| 354 |
+
auto cfg = ctx->GetNumericalCheckConfig();
|
| 355 |
+
auto c_dtype = c10::CppTypeToScalarType<T>::value;
|
| 356 |
+
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
char transa{};
|
| 360 |
+
char transb{};
|
| 361 |
+
int64_t m{};
|
| 362 |
+
int64_t n{};
|
| 363 |
+
int64_t k{};
|
| 364 |
+
at::opmath_type<T> alpha;
|
| 365 |
+
const T* a{};
|
| 366 |
+
int64_t lda{};
|
| 367 |
+
const T* b{};
|
| 368 |
+
int64_t ldb{};
|
| 369 |
+
at::opmath_type<T> beta;
|
| 370 |
+
T* c{};
|
| 371 |
+
int64_t ldc{};
|
| 372 |
+
private:
|
| 373 |
+
bool duplicate_inputs_{false};
|
| 374 |
+
};
|
| 375 |
+
|
| 376 |
+
template <typename T>
|
| 377 |
+
struct GemmAndBiasParams : OpParams {
|
| 378 |
+
std::string BLASSignature() const override {
|
| 379 |
+
std::string alpha_str = to_string_opmath<T>(alpha);
|
| 380 |
+
std::string activation_str = to_string_epilogue(activation);
|
| 381 |
+
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
|
| 382 |
+
"alpha: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, activation: %s, bias_type: %s, scale_type: %s, compute_type: %s }",
|
| 383 |
+
m, n, k, lda, ldb, ldc, ldc, alpha_str, transa, transb,
|
| 384 |
+
BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), activation_str, BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>(), ComputeTypeFor<T>());
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
std::string Signature() const override {
|
| 388 |
+
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc);
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
size_t GetSizeA() const {
|
| 392 |
+
size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m);
|
| 393 |
+
size_t size_dense = m * k;
|
| 394 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
size_t GetSizeB() const {
|
| 398 |
+
size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k);
|
| 399 |
+
size_t size_dense = k * n;
|
| 400 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
size_t GetSizeC() const {
|
| 404 |
+
size_t size_stride = ldc * n;
|
| 405 |
+
size_t size_dense = m * n;
|
| 406 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
size_t GetSize(bool duplicate_inputs) const {
|
| 410 |
+
size_t size = GetSizeC();
|
| 411 |
+
if (duplicate_inputs) {
|
| 412 |
+
size += GetSizeA();
|
| 413 |
+
size += GetSizeB();
|
| 414 |
+
}
|
| 415 |
+
return size;
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const {
|
| 419 |
+
GemmAndBiasParams* copy = new GemmAndBiasParams;
|
| 420 |
+
*copy = *this;
|
| 421 |
+
c10::DeviceIndex device = 0;
|
| 422 |
+
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
| 423 |
+
size_t c_size = GetSizeC();
|
| 424 |
+
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
|
| 425 |
+
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
|
| 426 |
+
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
|
| 427 |
+
if (duplicate_inputs) {
|
| 428 |
+
size_t a_size = GetSizeA();
|
| 429 |
+
size_t b_size = GetSizeB();
|
| 430 |
+
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
|
| 431 |
+
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
|
| 432 |
+
copy->duplicate_inputs_ = true;
|
| 433 |
+
}
|
| 434 |
+
return copy;
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
// only call on object returned by DeepCopy
|
| 438 |
+
void Delete() {
|
| 439 |
+
c10::cuda::CUDACachingAllocator::raw_delete(c);
|
| 440 |
+
if (duplicate_inputs_) {
|
| 441 |
+
// NOLINTNEXTLINE(*const-cast)
|
| 442 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
|
| 443 |
+
// NOLINTNEXTLINE(*const-cast)
|
| 444 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
|
| 449 |
+
auto* ctx = getTuningContext();
|
| 450 |
+
auto cfg = ctx->GetNumericalCheckConfig();
|
| 451 |
+
auto c_dtype = c10::CppTypeToScalarType<T>::value;
|
| 452 |
+
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
char transa{};
|
| 456 |
+
char transb{};
|
| 457 |
+
int64_t m{};
|
| 458 |
+
int64_t n{};
|
| 459 |
+
int64_t k{};
|
| 460 |
+
at::opmath_type<T> alpha{};
|
| 461 |
+
const T* a{};
|
| 462 |
+
int64_t lda{};
|
| 463 |
+
const T* b{};
|
| 464 |
+
int64_t ldb{};
|
| 465 |
+
T* c{};
|
| 466 |
+
int64_t ldc{};
|
| 467 |
+
const T* bias{};
|
| 468 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue activation{};
|
| 469 |
+
private:
|
| 470 |
+
bool duplicate_inputs_{false};
|
| 471 |
+
};
|
| 472 |
+
|
| 473 |
+
template <typename T, typename C_Dtype = T>
|
| 474 |
+
struct GemmStridedBatchedParams : OpParams {
|
| 475 |
+
std::string BLASSignature() const override {
|
| 476 |
+
std::string alpha_str = to_string_opmath<T>(alpha);
|
| 477 |
+
std::string beta_str = to_string_opmath<T>(beta);
|
| 478 |
+
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: %ld, stride_b: %ld, stride_c: %ld, stride_d: %ld, "
|
| 479 |
+
"alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: %ld, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }",
|
| 480 |
+
m, n, k, lda, ldb, ldc, ldc, stride_a, stride_b, stride_c, stride_c, alpha_str, beta_str, transa, transb, batch,
|
| 481 |
+
BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<C_Dtype>(C_Dtype{}), BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>());
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
std::string Signature() const override {
|
| 485 |
+
return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, batch, lda, ldb, ldc);
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
size_t GetSizeA() const {
|
| 489 |
+
size_t size_stride = stride_a * batch;
|
| 490 |
+
size_t size_dense = m * k * batch;
|
| 491 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
size_t GetSizeB() const {
|
| 495 |
+
size_t size_stride = stride_b * batch;
|
| 496 |
+
size_t size_dense = k * n * batch;
|
| 497 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
size_t GetSizeC() const {
|
| 501 |
+
size_t size_stride = stride_c * batch;
|
| 502 |
+
size_t size_dense = m * n * batch;
|
| 503 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
size_t GetSize(bool duplicate_inputs) const {
|
| 507 |
+
size_t size = GetSizeC();
|
| 508 |
+
if (duplicate_inputs) {
|
| 509 |
+
size += GetSizeA();
|
| 510 |
+
size += GetSizeB();
|
| 511 |
+
}
|
| 512 |
+
return size;
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
|
| 516 |
+
GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
|
| 517 |
+
*copy = *this;
|
| 518 |
+
c10::DeviceIndex device = 0;
|
| 519 |
+
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
| 520 |
+
size_t c_size = GetSizeC();
|
| 521 |
+
copy->c = static_cast<C_Dtype*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
|
| 522 |
+
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
|
| 523 |
+
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
|
| 524 |
+
if (duplicate_inputs) {
|
| 525 |
+
size_t a_size = GetSizeA();
|
| 526 |
+
size_t b_size = GetSizeB();
|
| 527 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 528 |
+
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
|
| 529 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 530 |
+
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
|
| 531 |
+
copy->duplicate_inputs_ = true;
|
| 532 |
+
}
|
| 533 |
+
return copy;
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
// only call on object returned by DeepCopy
|
| 537 |
+
void Delete() {
|
| 538 |
+
c10::cuda::CUDACachingAllocator::raw_delete(c);
|
| 539 |
+
if (duplicate_inputs_) {
|
| 540 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 541 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
|
| 542 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 543 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
|
| 544 |
+
}
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
|
| 548 |
+
auto* ctx = getTuningContext();
|
| 549 |
+
auto cfg = ctx->GetNumericalCheckConfig();
|
| 550 |
+
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
|
| 551 |
+
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
char transa{};
|
| 555 |
+
char transb{};
|
| 556 |
+
int64_t m{};
|
| 557 |
+
int64_t n{};
|
| 558 |
+
int64_t k{};
|
| 559 |
+
at::opmath_type<T> alpha{};
|
| 560 |
+
const T* a{};
|
| 561 |
+
int64_t lda{};
|
| 562 |
+
int64_t stride_a{};
|
| 563 |
+
const T* b{};
|
| 564 |
+
int64_t ldb{};
|
| 565 |
+
int64_t stride_b{};
|
| 566 |
+
at::opmath_type<T> beta;
|
| 567 |
+
C_Dtype* c{};
|
| 568 |
+
int64_t ldc{};
|
| 569 |
+
int64_t stride_c{};
|
| 570 |
+
int64_t batch{};
|
| 571 |
+
private:
|
| 572 |
+
bool duplicate_inputs_{false};
|
| 573 |
+
};
|
| 574 |
+
|
| 575 |
+
template <typename T>
|
| 576 |
+
struct ScaledGemmParams : OpParams {
|
| 577 |
+
ScaledGemmParams() = default;
|
| 578 |
+
|
| 579 |
+
std::string BLASSignature() const override {
|
| 580 |
+
// Excluding use_fast_accum and use_rowise booleans for now
|
| 581 |
+
if (bias_ptr == nullptr) {
|
| 582 |
+
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
|
| 583 |
+
"transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }",
|
| 584 |
+
m, n, k, lda, ldb, ldc, ldc, transa, transb,
|
| 585 |
+
ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype),
|
| 586 |
+
ComputeTypeFor<T>(), ComputeTypeFor<T>());
|
| 587 |
+
}
|
| 588 |
+
else {
|
| 589 |
+
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
|
| 590 |
+
"transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, bias_type: %s, scale_type: %s, compute_type: %s }",
|
| 591 |
+
m, n, k, lda, ldb, ldc, ldc, transa, transb,
|
| 592 |
+
ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(bias_dtype),
|
| 593 |
+
ComputeTypeFor<T>(), ComputeTypeFor<T>());
|
| 594 |
+
}
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
std::string Signature() const override {
|
| 598 |
+
// In Blas.cpp, code defaults to a bias_dtype of Half even when there is no bias vector.
|
| 599 |
+
// Search for this line::
|
| 600 |
+
// params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
|
| 601 |
+
//
|
| 602 |
+
// In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector.
|
| 603 |
+
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s",
|
| 604 |
+
transa, transb, m, n, k, lda, ldb, ldc,
|
| 605 |
+
a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise,
|
| 606 |
+
bias_ptr == nullptr ? "None" : at::toString(bias_dtype));
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
size_t GetSizeA() const {
|
| 610 |
+
size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m);
|
| 611 |
+
size_t size_dense = m * k;
|
| 612 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
size_t GetSizeB() const {
|
| 616 |
+
size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k);
|
| 617 |
+
size_t size_dense = k * n;
|
| 618 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
size_t GetSizeC() const {
|
| 622 |
+
size_t size_stride = ldc * n;
|
| 623 |
+
size_t size_dense = m * n;
|
| 624 |
+
return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
size_t GetSize(bool duplicate_inputs) const {
|
| 628 |
+
size_t size = GetSizeC();
|
| 629 |
+
if (duplicate_inputs) {
|
| 630 |
+
size += GetSizeA();
|
| 631 |
+
size += GetSizeB();
|
| 632 |
+
}
|
| 633 |
+
return size;
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
|
| 637 |
+
ScaledGemmParams* copy = new ScaledGemmParams;
|
| 638 |
+
*copy = *this;
|
| 639 |
+
c10::DeviceIndex device = 0;
|
| 640 |
+
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
| 641 |
+
size_t c_size = GetSizeC();
|
| 642 |
+
copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
|
| 643 |
+
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
|
| 644 |
+
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
|
| 645 |
+
if (duplicate_inputs) {
|
| 646 |
+
size_t a_size = GetSizeA();
|
| 647 |
+
size_t b_size = GetSizeB();
|
| 648 |
+
copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
|
| 649 |
+
copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
|
| 650 |
+
copy->duplicate_inputs_ = true;
|
| 651 |
+
}
|
| 652 |
+
return copy;
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
// only call on object returned by DeepCopy
|
| 656 |
+
void Delete() {
|
| 657 |
+
c10::cuda::CUDACachingAllocator::raw_delete(c);
|
| 658 |
+
if (duplicate_inputs_) {
|
| 659 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 660 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
|
| 661 |
+
// NOLINTNEXTLINE(*const-cast*)
|
| 662 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
|
| 663 |
+
}
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
|
| 667 |
+
auto* ctx = getTuningContext();
|
| 668 |
+
auto cfg = ctx->GetNumericalCheckConfig();
|
| 669 |
+
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
| 670 |
+
}
|
| 671 |
+
|
| 672 |
+
char transa{};
|
| 673 |
+
char transb{};
|
| 674 |
+
int64_t m{};
|
| 675 |
+
int64_t n{};
|
| 676 |
+
int64_t k{};
|
| 677 |
+
const void* a{};
|
| 678 |
+
const void* a_scale_ptr{};
|
| 679 |
+
int64_t lda{};
|
| 680 |
+
ScalarType a_dtype{};
|
| 681 |
+
ScalarType a_scale_dtype{};
|
| 682 |
+
ScalingType a_scaling_type{};
|
| 683 |
+
const void* b{};
|
| 684 |
+
const void* b_scale_ptr{};
|
| 685 |
+
int64_t ldb{};
|
| 686 |
+
ScalarType b_dtype{};
|
| 687 |
+
ScalarType b_scale_dtype{};
|
| 688 |
+
ScalingType b_scaling_type{};
|
| 689 |
+
const void* bias_ptr{};
|
| 690 |
+
ScalarType bias_dtype{};
|
| 691 |
+
void* c{};
|
| 692 |
+
const void* c_scale_ptr{};
|
| 693 |
+
int64_t ldc{};
|
| 694 |
+
ScalarType c_dtype{};
|
| 695 |
+
void* amax_ptr{};
|
| 696 |
+
bool use_fast_accum{};
|
| 697 |
+
private:
|
| 698 |
+
bool duplicate_inputs_{false};
|
| 699 |
+
};
|
| 700 |
+
|
| 701 |
+
} // namespace at::cuda::tunable
|
| 702 |
+
|
| 703 |
+
#else
|
| 704 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 705 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 3 |
+
// Licensed under the MIT License.
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 8 |
+
#include <ATen/cuda/CUDADataType.h>
|
| 9 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 10 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 11 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 12 |
+
#include <c10/util/StringUtil.h>
|
| 13 |
+
#include <fmt/printf.h>
|
| 14 |
+
|
| 15 |
+
#include <hipblaslt/hipblaslt.h>
|
| 16 |
+
#include <hipblaslt/hipblaslt-ext.hpp>
|
| 17 |
+
|
| 18 |
+
// Evaluates EXPR once and raises via TORCH_CHECK unless the result is
// HIPBLAS_STATUS_SUCCESS. The message contains the status string and the
// source text of EXPR.
// The local is not named `__err`: identifiers containing a double underscore
// are reserved for the implementation. EXPR is parenthesized so that any
// expression can be passed safely.
#define TORCH_HIPBLASLT_CHECK(EXPR)                                       \
  do {                                                                    \
    const hipblasStatus_t torch_hipblaslt_status_ = (EXPR);               \
    TORCH_CHECK(torch_hipblaslt_status_ == HIPBLAS_STATUS_SUCCESS,        \
                "hipblaslt error: ",                                      \
                hipblasStatusToString(torch_hipblaslt_status_),           \
                " when calling `" #EXPR "`");                             \
  } while (0)
|
| 26 |
+
|
| 27 |
+
namespace at::cuda::tunable {
|
| 28 |
+
|
| 29 |
+
template <typename T>
|
| 30 |
+
constexpr hipDataType HipDataTypeFor();
|
| 31 |
+
|
| 32 |
+
template <>
|
| 33 |
+
constexpr hipDataType HipDataTypeFor<float>() {
|
| 34 |
+
return HIP_R_32F;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
template <>
|
| 38 |
+
constexpr hipDataType HipDataTypeFor<Half>() {
|
| 39 |
+
return HIP_R_16F;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
template <>
|
| 43 |
+
constexpr hipDataType HipDataTypeFor<BFloat16>() {
|
| 44 |
+
return HIP_R_16BF;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
template <>
|
| 48 |
+
constexpr hipDataType HipDataTypeFor<double>() {
|
| 49 |
+
return HIP_R_64F;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
template <>
|
| 53 |
+
constexpr hipDataType HipDataTypeFor<c10::Float8_e4m3fnuz>() {
|
| 54 |
+
return HIP_R_8F_E4M3_FNUZ;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
template <>
|
| 58 |
+
constexpr hipDataType HipDataTypeFor<c10::Float8_e5m2fnuz>() {
|
| 59 |
+
return HIP_R_8F_E5M2_FNUZ;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// This code is instantiated regardless of ROCm version.
|
| 63 |
+
// Prior to ROCm 6.3, we hard-code the known enum values.
|
| 64 |
+
template <>
|
| 65 |
+
constexpr hipDataType HipDataTypeFor<c10::Float8_e4m3fn>() {
|
| 66 |
+
#if ROCM_VERSION >= 60300
|
| 67 |
+
return HIP_R_8F_E4M3;
|
| 68 |
+
#else
|
| 69 |
+
return static_cast<hipDataType>(28);
|
| 70 |
+
#endif
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
template <>
|
| 74 |
+
constexpr hipDataType HipDataTypeFor<c10::Float8_e5m2>() {
|
| 75 |
+
#if ROCM_VERSION >= 60300
|
| 76 |
+
return HIP_R_8F_E5M2;
|
| 77 |
+
#else
|
| 78 |
+
return static_cast<hipDataType>(29);
|
| 79 |
+
#endif
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// This type is not intended for matrix types but rather a scale factor.
|
| 83 |
+
// Return a dummy value to satisfy linker.
|
| 84 |
+
template <>
|
| 85 |
+
constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
|
| 86 |
+
return static_cast<hipDataType>(500);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
template <>
|
| 90 |
+
constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
|
| 91 |
+
#if ROCM_VERSION >= 70000
|
| 92 |
+
return HIP_R_4F_E2M1;
|
| 93 |
+
#else
|
| 94 |
+
return static_cast<hipDataType>(33);
|
| 95 |
+
#endif
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
template <typename T>
|
| 99 |
+
int GetBatchFromParams(const GemmParams<T>* params) {
|
| 100 |
+
return 1;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
template <typename T>
|
| 104 |
+
int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
|
| 105 |
+
return 1;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
template <typename T>
|
| 109 |
+
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 110 |
+
return params->batch;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
template <typename T>
|
| 114 |
+
int GetBatchFromParams(const ScaledGemmParams<T>* params) {
|
| 115 |
+
return 1;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
template <typename T>
|
| 119 |
+
int GetStrideAFromParams(const GemmParams<T>* params) {
|
| 120 |
+
return 1;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
template <typename T>
|
| 124 |
+
int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
|
| 125 |
+
return 1;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
template <typename T>
|
| 129 |
+
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 130 |
+
return params->stride_a;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
template <typename T>
|
| 134 |
+
int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
|
| 135 |
+
return 1;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
template <typename T>
|
| 139 |
+
int GetStrideBFromParams(const GemmParams<T>* params) {
|
| 140 |
+
return 1;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
template <typename T>
|
| 144 |
+
int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
|
| 145 |
+
return 1;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
template <typename T>
|
| 149 |
+
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 150 |
+
return params->stride_b;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
template <typename T>
|
| 154 |
+
int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
|
| 155 |
+
return 1;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
template <typename T>
|
| 159 |
+
int GetStrideCFromParams(const GemmParams<T>* params) {
|
| 160 |
+
return 1;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
template <typename T>
|
| 164 |
+
int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
|
| 165 |
+
return 1;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
template <typename T>
|
| 169 |
+
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 170 |
+
return params->stride_c;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
template <typename T>
|
| 174 |
+
int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
|
| 175 |
+
return 1;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
template <typename T>
|
| 179 |
+
float GetAlphaFromParams(const GemmParams<T>* params) {
|
| 180 |
+
return params->alpha;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
template <typename T>
|
| 184 |
+
float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
|
| 185 |
+
return params->alpha;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
template <typename T>
|
| 189 |
+
float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 190 |
+
return params->alpha;
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
template <typename T>
|
| 194 |
+
float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
|
| 195 |
+
return 1.0;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
template <typename T>
|
| 199 |
+
float GetBetaFromParams(const GemmParams<T>* params) {
|
| 200 |
+
return params->beta;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
template <typename T>
|
| 204 |
+
float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
|
| 205 |
+
return 0.0;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
template <typename T>
|
| 209 |
+
float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 210 |
+
return params->beta;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
template <typename T>
|
| 214 |
+
float GetBetaFromParams(const ScaledGemmParams<T>* params) {
|
| 215 |
+
return 0.0;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
template <typename T>
|
| 219 |
+
ScalingType GetAScalingTypeFromParams(const GemmParams<T>* params) {
|
| 220 |
+
return ScalingType::TensorWise;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
template <typename T>
|
| 224 |
+
ScalingType GetBScalingTypeFromParams(const GemmParams<T>* params) {
|
| 225 |
+
return ScalingType::TensorWise;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
template <typename T>
|
| 229 |
+
ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
|
| 230 |
+
return ScalingType::TensorWise;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
template <typename T>
|
| 234 |
+
ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
|
| 235 |
+
return ScalingType::TensorWise;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
template <typename T>
|
| 239 |
+
ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 240 |
+
return ScalingType::TensorWise;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
template <typename T>
|
| 244 |
+
ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 245 |
+
return ScalingType::TensorWise;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
template <typename T>
|
| 249 |
+
ScalingType GetAScalingTypeFromParams(const ScaledGemmParams<T>* params) {
|
| 250 |
+
return params->a_scaling_type;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
template <typename T>
|
| 254 |
+
ScalingType GetBScalingTypeFromParams(const ScaledGemmParams<T>* params) {
|
| 255 |
+
return params->b_scaling_type;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
template <typename T>
|
| 259 |
+
const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
|
| 260 |
+
return nullptr;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
template <typename T>
|
| 264 |
+
const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 265 |
+
return nullptr;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
template <typename T>
|
| 269 |
+
const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 270 |
+
return nullptr;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
template <typename T>
|
| 274 |
+
const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
|
| 275 |
+
return params->a_scale_ptr;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
template <typename T>
|
| 279 |
+
const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
|
| 280 |
+
return nullptr;
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
template <typename T>
|
| 284 |
+
const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 285 |
+
return nullptr;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
template <typename T>
|
| 289 |
+
const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 290 |
+
return nullptr;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
template <typename T>
|
| 294 |
+
const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
|
| 295 |
+
return params->b_scale_ptr;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
template <typename T>
|
| 299 |
+
const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
|
| 300 |
+
return nullptr;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
template <typename T>
|
| 304 |
+
const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 305 |
+
return nullptr;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
template <typename T>
|
| 309 |
+
const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 310 |
+
return nullptr;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
template <typename T>
|
| 314 |
+
const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
|
| 315 |
+
return params->c_scale_ptr;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
template <typename T>
|
| 319 |
+
const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
|
| 320 |
+
return nullptr;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
template <typename T>
|
| 324 |
+
const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 325 |
+
return params->bias;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
template <typename T>
|
| 329 |
+
const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 330 |
+
return nullptr;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
template <typename T>
|
| 334 |
+
const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
|
| 335 |
+
return params->bias_ptr;
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
template <typename T>
|
| 339 |
+
hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
|
| 340 |
+
return HIP_R_32F;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
template <typename T>
|
| 344 |
+
hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
|
| 345 |
+
return HipDataTypeFor<T>();
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
template <typename T>
|
| 349 |
+
hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 350 |
+
return HIP_R_32F;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
template <typename T>
|
| 354 |
+
hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
|
| 355 |
+
return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
template <typename T>
|
| 359 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
|
| 360 |
+
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
template <typename T>
|
| 364 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
|
| 365 |
+
return params->activation;
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
template <typename T>
|
| 369 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 370 |
+
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
template <typename T>
|
| 374 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
|
| 375 |
+
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
static hipblasOperation_t _hipblasOpFromChar(char op) {
|
| 379 |
+
switch (op) {
|
| 380 |
+
case 'n':
|
| 381 |
+
case 'N':
|
| 382 |
+
return HIPBLAS_OP_N;
|
| 383 |
+
case 't':
|
| 384 |
+
case 'T':
|
| 385 |
+
return HIPBLAS_OP_T;
|
| 386 |
+
case 'c':
|
| 387 |
+
case 'C':
|
| 388 |
+
return HIPBLAS_OP_C;
|
| 389 |
+
}
|
| 390 |
+
TORCH_CHECK(false,
|
| 391 |
+
"_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
static char _charFromhipblasOp(hipblasOperation_t op) {
|
| 395 |
+
switch (op) {
|
| 396 |
+
case HIPBLAS_OP_N:
|
| 397 |
+
return 'N';
|
| 398 |
+
case HIPBLAS_OP_T:
|
| 399 |
+
return 'T';
|
| 400 |
+
case HIPBLAS_OP_C:
|
| 401 |
+
return 'C';
|
| 402 |
+
}
|
| 403 |
+
TORCH_CHECK(false,
|
| 404 |
+
"_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
|
| 408 |
+
if (layout == BlasOp::N) {
|
| 409 |
+
return HIPBLAS_OP_N;
|
| 410 |
+
}
|
| 411 |
+
return HIPBLAS_OP_T;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
template <typename T, cublasStatus_t (*destructor)(T*)>
|
| 415 |
+
struct HipBlasLtDeleter {
|
| 416 |
+
void operator()(T* x) {
|
| 417 |
+
if (x != nullptr) {
|
| 418 |
+
TORCH_CUDABLAS_CHECK(destructor(x));
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
};
|
| 422 |
+
|
| 423 |
+
template <typename T, hipblasStatus_t (*destructor)(T*)>
|
| 424 |
+
class HipBlasLtDescriptor {
|
| 425 |
+
public:
|
| 426 |
+
T* descriptor() const {
|
| 427 |
+
return descriptor_.get();
|
| 428 |
+
}
|
| 429 |
+
T* descriptor() {
|
| 430 |
+
return descriptor_.get();
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
protected:
|
| 434 |
+
std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
|
| 435 |
+
};
|
| 436 |
+
|
| 437 |
+
class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
|
| 438 |
+
hipblasLtMatmulDescOpaque_t,
|
| 439 |
+
&hipblasLtMatmulDescDestroy> {
|
| 440 |
+
public:
|
| 441 |
+
HipBlasLtMatmulDescriptor(
|
| 442 |
+
hipblasComputeType_t compute_type,
|
| 443 |
+
hipDataType scale_type) {
|
| 444 |
+
hipblasLtMatmulDesc_t raw_descriptor = nullptr;
|
| 445 |
+
TORCH_HIPBLASLT_CHECK(
|
| 446 |
+
hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
|
| 447 |
+
descriptor_.reset(raw_descriptor);
|
| 448 |
+
}
|
| 449 |
+
template <typename T>
|
| 450 |
+
inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
|
| 451 |
+
TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
|
| 452 |
+
}
|
| 453 |
+
};
|
| 454 |
+
|
| 455 |
+
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
|
| 456 |
+
class HipblasltGemmOp : public Callable<ParamsT> {
|
| 457 |
+
public:
|
| 458 |
+
HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
|
| 459 |
+
|
| 460 |
+
TuningStatus Call(const ParamsT* params) override {
|
| 461 |
+
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
|
| 462 |
+
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
|
| 463 |
+
auto a_datatype = HipDataTypeFor<AT>();
|
| 464 |
+
auto b_datatype = HipDataTypeFor<BT>();
|
| 465 |
+
auto in_out_datatype = HipDataTypeFor<CT>();
|
| 466 |
+
auto opa = _hipblasOpFromChar(params->transa);
|
| 467 |
+
auto opb = _hipblasOpFromChar(params->transb);
|
| 468 |
+
|
| 469 |
+
TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
|
| 470 |
+
|
| 471 |
+
float alpha = GetAlphaFromParams<CT>(params);
|
| 472 |
+
float beta = GetBetaFromParams<CT>(params);
|
| 473 |
+
|
| 474 |
+
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
|
| 475 |
+
if (opa == HIPBLAS_OP_N) {
|
| 476 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
|
| 477 |
+
}
|
| 478 |
+
else {
|
| 479 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
|
| 480 |
+
}
|
| 481 |
+
if (opb == HIPBLAS_OP_N) {
|
| 482 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
|
| 483 |
+
}
|
| 484 |
+
else {
|
| 485 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
|
| 486 |
+
}
|
| 487 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
|
| 488 |
+
|
| 489 |
+
// specific to batched gemmm
|
| 490 |
+
int batch = GetBatchFromParams<CT>(params);
|
| 491 |
+
if (batch > 1) {
|
| 492 |
+
int64_t stride_a = GetStrideAFromParams<CT>(params);
|
| 493 |
+
int64_t stride_b = GetStrideBFromParams<CT>(params);
|
| 494 |
+
int64_t stride_c = GetStrideCFromParams<CT>(params);
|
| 495 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 496 |
+
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 497 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 498 |
+
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
|
| 499 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 500 |
+
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 501 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 502 |
+
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
|
| 503 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 504 |
+
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 505 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 506 |
+
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
|
| 510 |
+
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
|
| 511 |
+
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
|
| 512 |
+
}
|
| 513 |
+
HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
|
| 514 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
|
| 515 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
|
| 516 |
+
|
| 517 |
+
// specific to scaled gemm
|
| 518 |
+
const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
|
| 519 |
+
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
|
| 520 |
+
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
|
| 521 |
+
if (mat1_scale_ptr && mat2_scale_ptr) {
|
| 522 |
+
hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
|
| 523 |
+
hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
|
| 524 |
+
if (GetAScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
|
| 525 |
+
#if defined(HIPBLASLT_OUTER_VEC)
|
| 526 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
|
| 527 |
+
#elif defined(HIPBLASLT_VEC_EXT)
|
| 528 |
+
a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
|
| 529 |
+
#endif
|
| 530 |
+
}
|
| 531 |
+
if (GetBScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
|
| 532 |
+
#if defined(HIPBLASLT_OUTER_VEC)
|
| 533 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
|
| 534 |
+
#elif defined(HIPBLASLT_VEC_EXT)
|
| 535 |
+
b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
|
| 536 |
+
#endif
|
| 537 |
+
}
|
| 538 |
+
matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr);
|
| 539 |
+
matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr);
|
| 540 |
+
}
|
| 541 |
+
if (result_scale_ptr) {
|
| 542 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
|
| 546 |
+
auto bias_datatype = GetBiasTypeFromParams<CT>(params);
|
| 547 |
+
if (bias_ptr) {
|
| 548 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
|
| 549 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
|
| 550 |
+
auto activation = GetActivationFromParams<CT>(params);
|
| 551 |
+
if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
|
| 552 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
|
| 553 |
+
}
|
| 554 |
+
else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
|
| 555 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
|
| 556 |
+
}
|
| 557 |
+
else {
|
| 558 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
|
| 559 |
+
}
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
size_t workspace_size = at::cuda::getCUDABlasLtWorkspaceSize();
|
| 563 |
+
|
| 564 |
+
auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
|
| 565 |
+
|
| 566 |
+
size_t ret_workspace_size = 0;
|
| 567 |
+
auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
|
| 568 |
+
matmul.descriptor(),
|
| 569 |
+
&alpha,
|
| 570 |
+
mat_a,
|
| 571 |
+
mat_b,
|
| 572 |
+
&beta,
|
| 573 |
+
mat_c,
|
| 574 |
+
mat_c,
|
| 575 |
+
algo_,
|
| 576 |
+
ret_workspace_size);
|
| 577 |
+
|
| 578 |
+
if (status == HIPBLAS_STATUS_SUCCESS) {
|
| 579 |
+
if (ret_workspace_size >= workspace_size) {
|
| 580 |
+
return FAIL;
|
| 581 |
+
}
|
| 582 |
+
}
|
| 583 |
+
else {
|
| 584 |
+
return FAIL;
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
void* workspace_buffer = at::cuda::getCUDABlasLtWorkspace();
|
| 588 |
+
|
| 589 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
|
| 590 |
+
matmul.descriptor(),
|
| 591 |
+
&alpha,
|
| 592 |
+
params->a,
|
| 593 |
+
mat_a,
|
| 594 |
+
params->b,
|
| 595 |
+
mat_b,
|
| 596 |
+
&beta,
|
| 597 |
+
params->c,
|
| 598 |
+
mat_c,
|
| 599 |
+
params->c,
|
| 600 |
+
mat_c,
|
| 601 |
+
&algo_,
|
| 602 |
+
workspace_buffer,
|
| 603 |
+
workspace_size,
|
| 604 |
+
at::cuda::getCurrentCUDAStream()));
|
| 605 |
+
|
| 606 |
+
//TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
|
| 607 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
|
| 608 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
|
| 609 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
|
| 610 |
+
return OK;
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
private:
|
| 614 |
+
hipblasLtMatmulAlgo_t algo_;
|
| 615 |
+
};
|
| 616 |
+
|
| 617 |
+
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
|
| 618 |
+
auto GetHipBlasLtTypeStringAndOps() {
|
| 619 |
+
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
|
| 620 |
+
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
|
| 621 |
+
auto a_datatype = HipDataTypeFor<AT>();
|
| 622 |
+
auto b_datatype = HipDataTypeFor<BT>();
|
| 623 |
+
auto in_out_datatype = HipDataTypeFor<CT>();
|
| 624 |
+
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
|
| 625 |
+
#if ROCM_VERSION == 60400
|
| 626 |
+
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
| 627 |
+
if ((a_datatype == HIP_R_32F || b_datatype == HIP_R_32F || in_out_datatype == HIP_R_32F)
|
| 628 |
+
&& (transa_outer == HIPBLAS_OP_T && transb_outer == HIPBLAS_OP_T)) {
|
| 629 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ignore;
|
| 630 |
+
return ignore;
|
| 631 |
+
}
|
| 632 |
+
#endif
|
| 633 |
+
|
| 634 |
+
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
|
| 635 |
+
if (at::globalContext().allowTF32CuBLAS()) {
|
| 636 |
+
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
hipblasLtHandle_t handle;
|
| 640 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
|
| 641 |
+
TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
|
| 642 |
+
hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
|
| 643 |
+
transa_outer,
|
| 644 |
+
transb_outer,
|
| 645 |
+
a_datatype,
|
| 646 |
+
b_datatype,
|
| 647 |
+
in_out_datatype,
|
| 648 |
+
in_out_datatype,
|
| 649 |
+
computeType,
|
| 650 |
+
heuristic_result));
|
| 651 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
|
| 652 |
+
|
| 653 |
+
int returned_algo_count = heuristic_result.size();
|
| 654 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
|
| 655 |
+
for (int i = 0; i < returned_algo_count; i++) {
|
| 656 |
+
auto algo = heuristic_result[i].algo;
|
| 657 |
+
int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
|
| 658 |
+
auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
|
| 659 |
+
std::string type_string = fmt::sprintf("Gemm_Hipblaslt_%d", algo_index);
|
| 660 |
+
ret.emplace_back(type_string, std::move(callable));
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
return ret;
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
| 667 |
+
auto GetHipBlasLtGemmTypeStringAndOps() {
|
| 668 |
+
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
| 672 |
+
auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
|
| 673 |
+
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
| 677 |
+
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
|
| 678 |
+
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
|
| 682 |
+
auto GetHipBlasLtScaledGemmTypeStringAndOps() {
|
| 683 |
+
return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
#undef TORCH_HIPBLASLT_CHECK
|
| 687 |
+
|
| 688 |
+
} // namespace at::cuda::tunable
|
| 689 |
+
|
| 690 |
+
#else
|
| 691 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 692 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 3 |
+
// Licensed under the MIT License.
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 8 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 9 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 10 |
+
#include <c10/util/StringUtil.h>
|
| 11 |
+
#include <fmt/printf.h>
|
| 12 |
+
|
| 13 |
+
#define ROCBLAS_BETA_FEATURES_API
|
| 14 |
+
#include <rocblas/rocblas.h>
|
| 15 |
+
|
| 16 |
+
#define TORCH_ROCBLAS_CHECK(EXPR) \
|
| 17 |
+
do { \
|
| 18 |
+
rocblas_status __err = EXPR; \
|
| 19 |
+
TORCH_CHECK(__err == rocblas_status_success, \
|
| 20 |
+
"rocblas error: ", \
|
| 21 |
+
rocblas_status_to_string(__err), \
|
| 22 |
+
" when calling `" #EXPR "`"); \
|
| 23 |
+
} while (0)
|
| 24 |
+
|
| 25 |
+
namespace at::cuda::tunable {
|
| 26 |
+
|
| 27 |
+
template <typename T>
|
| 28 |
+
constexpr rocblas_datatype RocBlasDataTypeFor();
|
| 29 |
+
|
| 30 |
+
template <>
|
| 31 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
|
| 32 |
+
return rocblas_datatype_f32_r;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
template <>
|
| 36 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
|
| 37 |
+
return rocblas_datatype_f64_r;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
template <>
|
| 41 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
|
| 42 |
+
return rocblas_datatype_f16_r;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template <>
|
| 46 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
|
| 47 |
+
return rocblas_datatype_bf16_r;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
template <>
|
| 51 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
|
| 52 |
+
return rocblas_datatype_f32_c;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
template <>
|
| 56 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
|
| 57 |
+
return rocblas_datatype_f64_c;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
template <typename T>
|
| 61 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor();
|
| 62 |
+
|
| 63 |
+
template <>
|
| 64 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
|
| 65 |
+
return rocblas_datatype_f32_r;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
template <>
|
| 69 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
|
| 70 |
+
return rocblas_datatype_f64_r;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
template <>
|
| 74 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
|
| 75 |
+
// Note that we're returning the _compute_ type for a given datatype.
|
| 76 |
+
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
| 77 |
+
// slower than using compute type FP32. So we use FP32 compute even for
|
| 78 |
+
// FP16 datatypes. This is how GEMM is implemented even in the function
|
| 79 |
+
// rocblasGemmHelper (see fpgeneric.h)
|
| 80 |
+
return rocblas_datatype_f32_r;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template <>
|
| 84 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
|
| 85 |
+
// Note that we're returning the _compute_ type for a given datatype.
|
| 86 |
+
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
| 87 |
+
// slower than using compute type FP32. So we use FP32 compute even for
|
| 88 |
+
// BF16 datatypes. This is how GEMM is implemented even in the function
|
| 89 |
+
// rocblasGemmHelper (see fpgeneric.h)
|
| 90 |
+
return rocblas_datatype_f32_r;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
template <>
|
| 94 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
|
| 95 |
+
return rocblas_datatype_f32_c;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
template <>
|
| 99 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
|
| 100 |
+
return rocblas_datatype_f64_c;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
template <typename T>
|
| 104 |
+
auto DoCastForHalfOrBfloat16(const T fp) {
|
| 105 |
+
return fp;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
template <>
|
| 109 |
+
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
|
| 110 |
+
// alpha and beta should be the same as compute_type, in Half case it is float.
|
| 111 |
+
float h = fp;
|
| 112 |
+
return h;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
template <>
|
| 116 |
+
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
|
| 117 |
+
// alpha and beta should be the same as compute_type, in bfloat16 case it is float.
|
| 118 |
+
float h = fp;
|
| 119 |
+
return h;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
static rocblas_operation _rocblasOpFromChar(char op) {
|
| 123 |
+
switch (op) {
|
| 124 |
+
case 'n':
|
| 125 |
+
case 'N':
|
| 126 |
+
return rocblas_operation_none;
|
| 127 |
+
case 't':
|
| 128 |
+
case 'T':
|
| 129 |
+
return rocblas_operation_transpose;
|
| 130 |
+
case 'c':
|
| 131 |
+
case 'C':
|
| 132 |
+
return rocblas_operation_conjugate_transpose;
|
| 133 |
+
}
|
| 134 |
+
TORCH_CHECK(false,
|
| 135 |
+
"_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
template <typename T>
|
| 139 |
+
class RocblasGemmOp : public Callable<GemmParams<T>> {
|
| 140 |
+
public:
|
| 141 |
+
RocblasGemmOp(int solution) : solution_{solution} {}
|
| 142 |
+
|
| 143 |
+
TuningStatus Call(const GemmParams<T>* params) override {
|
| 144 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 145 |
+
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
|
| 146 |
+
return FAIL; // no support for TF32 in rocBLAS
|
| 147 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 148 |
+
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
| 149 |
+
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
| 150 |
+
auto status = rocblas_gemm_ex(
|
| 151 |
+
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
|
| 152 |
+
_rocblasOpFromChar(params->transa),
|
| 153 |
+
_rocblasOpFromChar(params->transb),
|
| 154 |
+
params->m, params->n, params->k,
|
| 155 |
+
&h_a,
|
| 156 |
+
params->a, input_output_type, params->lda,
|
| 157 |
+
params->b, input_output_type, params->ldb,
|
| 158 |
+
&h_b,
|
| 159 |
+
params->c, input_output_type, params->ldc,
|
| 160 |
+
params->c, input_output_type, params->ldc,
|
| 161 |
+
compute_type,
|
| 162 |
+
rocblas_gemm_algo_solution_index,
|
| 163 |
+
solution_,
|
| 164 |
+
rocblas_gemm_flags_none);
|
| 165 |
+
if (status != rocblas_status_success) {
|
| 166 |
+
return FAIL;
|
| 167 |
+
}
|
| 168 |
+
return OK;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
private:
|
| 172 |
+
int solution_;
|
| 173 |
+
};
|
| 174 |
+
|
| 175 |
+
template <typename T>
|
| 176 |
+
auto GetRocBlasGemmTypeStringAndOps() {
|
| 177 |
+
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
| 178 |
+
int solution_size;
|
| 179 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 180 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 181 |
+
// Get the number of available solutions
|
| 182 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 183 |
+
input_output_type,
|
| 184 |
+
input_output_type,
|
| 185 |
+
compute_type,
|
| 186 |
+
rocblas_gemm_flags_none,
|
| 187 |
+
nullptr,
|
| 188 |
+
&solution_size));
|
| 189 |
+
std::vector<int> solutions(solution_size);
|
| 190 |
+
// Get the list of available solutions
|
| 191 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 192 |
+
input_output_type,
|
| 193 |
+
input_output_type,
|
| 194 |
+
compute_type,
|
| 195 |
+
rocblas_gemm_flags_none,
|
| 196 |
+
solutions.data(),
|
| 197 |
+
&solution_size));
|
| 198 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
|
| 199 |
+
for (size_t i = 0; i < solutions.size(); ++i) {
|
| 200 |
+
auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
|
| 201 |
+
ret.emplace_back(std::make_pair(fmt::sprintf("Gemm_Rocblas_%d", solutions[i]), std::move(callable)));
|
| 202 |
+
}
|
| 203 |
+
return ret;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
template <typename T>
|
| 207 |
+
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
|
| 208 |
+
public:
|
| 209 |
+
RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
|
| 210 |
+
|
| 211 |
+
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
| 212 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 213 |
+
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
|
| 214 |
+
return FAIL; // no support for TF32 in rocBLAS
|
| 215 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 216 |
+
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
| 217 |
+
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
| 218 |
+
auto status = rocblas_gemm_strided_batched_ex(
|
| 219 |
+
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
|
| 220 |
+
_rocblasOpFromChar(params->transa),
|
| 221 |
+
_rocblasOpFromChar(params->transb),
|
| 222 |
+
params->m, params->n, params->k,
|
| 223 |
+
&h_a,
|
| 224 |
+
params->a, input_output_type, params->lda, params->stride_a,
|
| 225 |
+
params->b, input_output_type, params->ldb, params->stride_b,
|
| 226 |
+
&h_b,
|
| 227 |
+
params->c, input_output_type, params->ldc, params->stride_c,
|
| 228 |
+
params->c, input_output_type, params->ldc, params->stride_c,
|
| 229 |
+
params->batch,
|
| 230 |
+
compute_type,
|
| 231 |
+
rocblas_gemm_algo_solution_index,
|
| 232 |
+
solution_,
|
| 233 |
+
rocblas_gemm_flags_none);
|
| 234 |
+
if (status != rocblas_status_success) {
|
| 235 |
+
return FAIL;
|
| 236 |
+
}
|
| 237 |
+
return OK;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
private:
|
| 241 |
+
int solution_;
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
template <typename T>
|
| 245 |
+
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
|
| 246 |
+
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
| 247 |
+
int solution_size;
|
| 248 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 249 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 250 |
+
// Get the number of available solutions
|
| 251 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 252 |
+
input_output_type,
|
| 253 |
+
input_output_type,
|
| 254 |
+
compute_type,
|
| 255 |
+
rocblas_gemm_flags_none,
|
| 256 |
+
nullptr,
|
| 257 |
+
&solution_size));
|
| 258 |
+
std::vector<int> solutions(solution_size);
|
| 259 |
+
// Get the list of available solutions
|
| 260 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 261 |
+
input_output_type,
|
| 262 |
+
input_output_type,
|
| 263 |
+
compute_type,
|
| 264 |
+
rocblas_gemm_flags_none,
|
| 265 |
+
solutions.data(),
|
| 266 |
+
&solution_size));
|
| 267 |
+
// Sort the solutions in ascending order to make the solution vector deterministic across runs
|
| 268 |
+
std::sort(solutions.begin(), solutions.end());
|
| 269 |
+
|
| 270 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
|
| 271 |
+
for (size_t i = 0; i < solutions.size(); ++i) {
|
| 272 |
+
auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
|
| 273 |
+
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
|
| 274 |
+
}
|
| 275 |
+
return ret;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
} // namespace at::cuda::tunable
|
| 279 |
+
|
| 280 |
+
#else
|
| 281 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 282 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Original TunableOp is from onnxruntime.
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 4 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 5 |
+
// Copyright (c) Microsoft Corporation.
|
| 6 |
+
// Licensed under the MIT license.
|
| 7 |
+
//
|
| 8 |
+
// Adapting TunableOp into PyTorch
|
| 9 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 10 |
+
//
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include <cuda_runtime.h>
|
| 14 |
+
|
| 15 |
+
#include <ATen/cuda/tunable/Tunable.h>
|
| 16 |
+
|
| 17 |
+
namespace at::cuda::tunable {
|
| 18 |
+
|
| 19 |
+
class StreamTimer : public ITimer {
|
| 20 |
+
public:
|
| 21 |
+
StreamTimer();
|
| 22 |
+
~StreamTimer() override;
|
| 23 |
+
|
| 24 |
+
void Start() override;
|
| 25 |
+
|
| 26 |
+
void End() override;
|
| 27 |
+
|
| 28 |
+
float Duration() override;
|
| 29 |
+
|
| 30 |
+
private:
|
| 31 |
+
cudaEvent_t start_{};
|
| 32 |
+
cudaEvent_t end_{};
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
class StreamTimerNoSync : public ITimer {
|
| 36 |
+
public:
|
| 37 |
+
StreamTimerNoSync();
|
| 38 |
+
~StreamTimerNoSync() override;
|
| 39 |
+
|
| 40 |
+
void Start() override;
|
| 41 |
+
|
| 42 |
+
void End() override;
|
| 43 |
+
|
| 44 |
+
float Duration() override;
|
| 45 |
+
|
| 46 |
+
private:
|
| 47 |
+
cudaEvent_t start_{};
|
| 48 |
+
cudaEvent_t end_{};
|
| 49 |
+
};
|
| 50 |
+
|
| 51 |
+
} // namespace at::cuda::tunable
|
| 52 |
+
|
| 53 |
+
#else
|
| 54 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 55 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/Tunable.h
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Original TunableOp is from onnxruntime.
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 4 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 5 |
+
// Copyright (c) Microsoft Corporation.
|
| 6 |
+
// Licensed under the MIT license.
|
| 7 |
+
//
|
| 8 |
+
// Adapting TunableOp into PyTorch
|
| 9 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 10 |
+
//
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include <c10/util/CallOnce.h>
|
| 14 |
+
#include <c10/util/StringUtil.h>
|
| 15 |
+
#include <c10/util/env.h>
|
| 16 |
+
|
| 17 |
+
#include <fstream>
|
| 18 |
+
#include <functional>
|
| 19 |
+
#include <iostream>
|
| 20 |
+
#include <memory>
|
| 21 |
+
#include <mutex>
|
| 22 |
+
#include <string>
|
| 23 |
+
#include <unordered_map>
|
| 24 |
+
#include <unordered_set>
|
| 25 |
+
#include <utility>
|
| 26 |
+
|
| 27 |
+
#define TUNABLE_LOGV(LEVEL, ...) getTuningContext()->Log(LEVEL, __VA_ARGS__)
|
| 28 |
+
#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
|
| 29 |
+
#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
|
| 30 |
+
#define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)
|
| 31 |
+
|
| 32 |
+
namespace at::cuda::tunable {
|
| 33 |
+
|
| 34 |
+
enum TORCH_CUDA_CPP_API TuningStatus {
|
| 35 |
+
OK = 0,
|
| 36 |
+
FAIL = 1,
|
| 37 |
+
UNSUPPORTED = 2,
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
// Mapping from params signature to kernel id
|
| 41 |
+
class TORCH_CUDA_CPP_API ResultEntry {
|
| 42 |
+
public:
|
| 43 |
+
explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {}
|
| 44 |
+
explicit ResultEntry(std::string key, double time, std::string blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(std::move(blas_sig)) {}
|
| 45 |
+
bool operator==(const ResultEntry& other) const { return key_ == other.key_; }
|
| 46 |
+
bool operator!=(const ResultEntry& other) const { return key_ != other.key_; }
|
| 47 |
+
operator std::string () { return key_; }
|
| 48 |
+
std::string GetKey() const { return key_; }
|
| 49 |
+
double GetTime() const { return time_; }
|
| 50 |
+
friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
|
| 51 |
+
static ResultEntry Null() { return ResultEntry("Null", 0.0); }
|
| 52 |
+
static ResultEntry Default() { return ResultEntry("Default", 0.0); }
|
| 53 |
+
|
| 54 |
+
private:
|
| 55 |
+
std::string key_;
|
| 56 |
+
double time_;
|
| 57 |
+
std::string blas_sig_;
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
typedef std::unordered_map<std::string, ResultEntry> KernelMap;
|
| 61 |
+
typedef std::unordered_map<std::string, KernelMap> ResultsMap;
|
| 62 |
+
typedef std::unordered_map<std::string, std::unordered_set<std::string>> UntunedMap;
|
| 63 |
+
|
| 64 |
+
struct TORCH_CUDA_CPP_API TuningResults {
|
| 65 |
+
// Validates if these results are compatible with the libraries
|
| 66 |
+
std::unordered_map<std::string, std::string> validators;
|
| 67 |
+
|
| 68 |
+
// Mapping from Callable signature to Callable's tuning result
|
| 69 |
+
ResultsMap results;
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
class TORCH_CUDA_CPP_API TuningResultsManager {
|
| 73 |
+
public:
|
| 74 |
+
TuningResultsManager() = default;
|
| 75 |
+
~TuningResultsManager() = default;
|
| 76 |
+
|
| 77 |
+
KernelMap Lookup(const std::string& op_signature);
|
| 78 |
+
|
| 79 |
+
ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);
|
| 80 |
+
|
| 81 |
+
void AddImpl(const std::string& op_signature,
|
| 82 |
+
const std::string& params_signature,
|
| 83 |
+
ResultEntry best,
|
| 84 |
+
KernelMap& kernel_map);
|
| 85 |
+
|
| 86 |
+
void Add(const std::string& op_signature,
|
| 87 |
+
const std::string& params_signature,
|
| 88 |
+
ResultEntry best);
|
| 89 |
+
|
| 90 |
+
void Delete(const std::string& op_signature, const std::string& params_signature);
|
| 91 |
+
|
| 92 |
+
void DisjointMergeImpl(
|
| 93 |
+
const std::string& op_signature,
|
| 94 |
+
const KernelMap& kernel_map,
|
| 95 |
+
/*out*/ ResultsMap& results);
|
| 96 |
+
|
| 97 |
+
void Load(const ResultsMap& results_to_load);
|
| 98 |
+
|
| 99 |
+
ResultsMap Dump();
|
| 100 |
+
|
| 101 |
+
void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);
|
| 102 |
+
|
| 103 |
+
size_t GetSize();
|
| 104 |
+
|
| 105 |
+
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
|
| 106 |
+
const std::string& params_signature, const std::string& blas_signature);
|
| 107 |
+
|
| 108 |
+
void InitRealtimeAppend(
|
| 109 |
+
const std::string& filename,
|
| 110 |
+
const std::unordered_map<std::string, std::string>& validators);
|
| 111 |
+
|
| 112 |
+
void AppendResultLine(const std::string& op_sig,
|
| 113 |
+
const std::string& param_sig,
|
| 114 |
+
const ResultEntry& result);
|
| 115 |
+
|
| 116 |
+
void CloseRealtimeAppend(); // For clean shutdown
|
| 117 |
+
private:
|
| 118 |
+
std::mutex lock_;
|
| 119 |
+
std::mutex realtime_file_mutex_;
|
| 120 |
+
std::unique_ptr<std::ofstream> realtime_out_;
|
| 121 |
+
std::string realtime_filename_;
|
| 122 |
+
ResultsMap results_;
|
| 123 |
+
UntunedMap untuned_results_;
|
| 124 |
+
bool validators_written_ = false;
|
| 125 |
+
|
| 126 |
+
};
|
| 127 |
+
|
| 128 |
+
class TORCH_CUDA_CPP_API TuningResultsValidator {
|
| 129 |
+
public:
|
| 130 |
+
using GetFunc = std::function<std::string()>;
|
| 131 |
+
using ValidateFunc = std::function<TuningStatus(const std::string&)>;
|
| 132 |
+
using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;
|
| 133 |
+
|
| 134 |
+
TuningResultsValidator();
|
| 135 |
+
~TuningResultsValidator() = default;
|
| 136 |
+
|
| 137 |
+
std::unordered_map<std::string, std::string> GetAllValidators() const;
|
| 138 |
+
TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
|
| 139 |
+
void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
|
| 140 |
+
|
| 141 |
+
protected:
|
| 142 |
+
static std::string GetPyTorchVersion() ;
|
| 143 |
+
TuningStatus ValidatePyTorchVersion(const std::string& value) const;
|
| 144 |
+
|
| 145 |
+
public:
|
| 146 |
+
static constexpr const std::array mandatory_keys{"PT_VERSION"};
|
| 147 |
+
|
| 148 |
+
private:
|
| 149 |
+
GetValidateFuncs validators_;
|
| 150 |
+
};
|
| 151 |
+
|
| 152 |
+
struct NumericalCheckConfig {
|
| 153 |
+
bool enabled{false};
|
| 154 |
+
double atol{1e-5};
|
| 155 |
+
double rtol{1e-5};
|
| 156 |
+
|
| 157 |
+
NumericalCheckConfig() = default;
|
| 158 |
+
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
|
| 159 |
+
};
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class TORCH_CUDA_CPP_API TuningContext {
|
| 163 |
+
public:
|
| 164 |
+
TuningContext();
|
| 165 |
+
~TuningContext();
|
| 166 |
+
TuningContext(TuningContext &) = delete;
|
| 167 |
+
TuningContext(TuningContext &&) = delete;
|
| 168 |
+
TuningContext &operator=(TuningContext &) = delete;
|
| 169 |
+
TuningContext &operator=(TuningContext &&) = delete;
|
| 170 |
+
|
| 171 |
+
void EnableTunableOp(bool value);
|
| 172 |
+
bool IsTunableOpEnabled() const;
|
| 173 |
+
|
| 174 |
+
void EnableTuning(bool value);
|
| 175 |
+
bool IsTuningEnabled() const;
|
| 176 |
+
|
| 177 |
+
void EnableRecordUntuned(bool value);
|
| 178 |
+
bool IsRecordUntunedEnabled() const;
|
| 179 |
+
std::ofstream& GetUntunedFile();
|
| 180 |
+
|
| 181 |
+
void EnableNumericsCheck(bool value);
|
| 182 |
+
bool IsNumericsCheckEnabled() const;
|
| 183 |
+
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
|
| 184 |
+
NumericalCheckConfig GetNumericalCheckConfig() const;
|
| 185 |
+
|
| 186 |
+
void SetMaxTuningDurationMs(int max_duration_ms);
|
| 187 |
+
int GetMaxTuningDurationMs() const;
|
| 188 |
+
|
| 189 |
+
void SetMaxTuningIterations(int max_iter);
|
| 190 |
+
int GetMaxTuningIterations() const;
|
| 191 |
+
|
| 192 |
+
void SetMaxWarmupDurationMs(int max_duration_ms);
|
| 193 |
+
int GetMaxWarmupDurationMs() const;
|
| 194 |
+
|
| 195 |
+
void SetMaxWarmupIterations(int max_iter);
|
| 196 |
+
int GetMaxWarmupIterations() const;
|
| 197 |
+
|
| 198 |
+
void EnableICacheFlush(bool value);
|
| 199 |
+
bool IsICacheFlushEnabled() const;
|
| 200 |
+
|
| 201 |
+
void SetRotatingBufferSize(int size);
|
| 202 |
+
int GetRotatingBufferSize() const;
|
| 203 |
+
|
| 204 |
+
TuningResultsManager& GetTuningResultsManager();
|
| 205 |
+
|
| 206 |
+
TuningResultsValidator& GetTuningResultsValidator();
|
| 207 |
+
|
| 208 |
+
TuningResults GetTuningResults();
|
| 209 |
+
|
| 210 |
+
TuningStatus LoadTuningResults(const TuningResults& tr);
|
| 211 |
+
|
| 212 |
+
void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
|
| 213 |
+
std::string GetFilename() const;
|
| 214 |
+
|
| 215 |
+
bool ReadFile(const std::string& filename={});
|
| 216 |
+
|
| 217 |
+
template<class... Types>
|
| 218 |
+
void Log(int level, Types... args) {
|
| 219 |
+
if (GetLogOkay() && GetLogLevel() >= level) {
|
| 220 |
+
GetLog() << c10::str(args...) << std::endl;
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
private:
|
| 225 |
+
std::string GetLogFilename() const;
|
| 226 |
+
int GetLogLevel() const;
|
| 227 |
+
bool GetLogOkay() const;
|
| 228 |
+
std::ostream& GetLog() const;
|
| 229 |
+
|
| 230 |
+
bool enable_;
|
| 231 |
+
bool tuning_enable_;
|
| 232 |
+
bool record_untuned_enable_;
|
| 233 |
+
bool manager_initialized_;
|
| 234 |
+
bool numerics_check_enable_;
|
| 235 |
+
int max_tuning_duration_ms_;
|
| 236 |
+
int max_tuning_iterations_;
|
| 237 |
+
int max_warmup_duration_ms_;
|
| 238 |
+
int max_warmup_iterations_;
|
| 239 |
+
bool icache_flush_;
|
| 240 |
+
int rotating_buffer_size_;
|
| 241 |
+
mutable TuningResultsManager manager_;
|
| 242 |
+
mutable c10::once_flag manager_init_once_;
|
| 243 |
+
TuningResultsValidator validator_;
|
| 244 |
+
std::string filename_;
|
| 245 |
+
std::ofstream untuned_file_;
|
| 246 |
+
size_t results_count_from_input_file_;
|
| 247 |
+
bool is_shutting_down_;
|
| 248 |
+
|
| 249 |
+
NumericalCheckConfig numerics_cfg_{};
|
| 250 |
+
};
|
| 251 |
+
|
| 252 |
+
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
|
| 253 |
+
|
| 254 |
+
class ITimer {
|
| 255 |
+
public:
|
| 256 |
+
ITimer() = default;
|
| 257 |
+
virtual ~ITimer() = default;
|
| 258 |
+
|
| 259 |
+
virtual void Start() = 0;
|
| 260 |
+
virtual void End() = 0;
|
| 261 |
+
|
| 262 |
+
/// Computes the elapsed time in milliseconds between Start() and End()
|
| 263 |
+
virtual float Duration() = 0;
|
| 264 |
+
};
|
| 265 |
+
|
| 266 |
+
} // namespace at::cuda::tunable
|
| 267 |
+
|
| 268 |
+
#else
|
| 269 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 270 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Original TunableOp is from onnxruntime.
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 4 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 5 |
+
// Copyright (c) Microsoft Corporation.
|
| 6 |
+
// Licensed under the MIT license.
|
| 7 |
+
//
|
| 8 |
+
// Adapting TunableOp into PyTorch
|
| 9 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 10 |
+
//
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 14 |
+
#ifdef USE_ROCM
|
| 15 |
+
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
| 16 |
+
#include <ATen/cuda/tunable/GemmRocblas.h>
|
| 17 |
+
#endif
|
| 18 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 19 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 20 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 21 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 22 |
+
#include <c10/util/Float8_e5m2.h>
|
| 23 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 24 |
+
#include <c10/util/Float8_e8m0fnu.h>
|
| 25 |
+
#include <c10/util/StringUtil.h>
|
| 26 |
+
#include <fmt/printf.h>
|
| 27 |
+
|
| 28 |
+
namespace at::cuda::tunable {
|
| 29 |
+
|
| 30 |
+
template <typename T>
|
| 31 |
+
class DefaultGemmOp : public Callable<GemmParams<T>> {
|
| 32 |
+
public:
|
| 33 |
+
TuningStatus Call(const GemmParams<T>* params) override {
|
| 34 |
+
at::cuda::blas::gemm_internal<T>(
|
| 35 |
+
params->transa, params->transb,
|
| 36 |
+
params->m, params->n, params->k,
|
| 37 |
+
params->alpha,
|
| 38 |
+
params->a, params->lda,
|
| 39 |
+
params->b, params->ldb,
|
| 40 |
+
params->beta,
|
| 41 |
+
params->c, params->ldc);
|
| 42 |
+
return OK;
|
| 43 |
+
}
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
static bool _transposeBoolFromChar(char op) {
|
| 47 |
+
return op == 't' || op == 'T';
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
template <typename T>
|
| 51 |
+
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
|
| 52 |
+
public:
|
| 53 |
+
TuningStatus Call(const GemmAndBiasParams<T>* params) override {
|
| 54 |
+
at::cuda::blas::gemm_and_bias<T>(
|
| 55 |
+
_transposeBoolFromChar(params->transa),
|
| 56 |
+
_transposeBoolFromChar(params->transb),
|
| 57 |
+
params->m, params->n, params->k,
|
| 58 |
+
params->alpha,
|
| 59 |
+
params->a, params->lda,
|
| 60 |
+
params->b, params->ldb,
|
| 61 |
+
params->bias,
|
| 62 |
+
params->c, params->ldc,
|
| 63 |
+
params->activation);
|
| 64 |
+
return OK;
|
| 65 |
+
}
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
template <typename T>
|
| 69 |
+
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
|
| 70 |
+
public:
|
| 71 |
+
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
| 72 |
+
at::cuda::blas::bgemm_internal<T>(
|
| 73 |
+
params->transa, params->transb,
|
| 74 |
+
params->m, params->n, params->k,
|
| 75 |
+
params->alpha,
|
| 76 |
+
params->a, params->lda, params->stride_a,
|
| 77 |
+
params->b, params->ldb, params->stride_b,
|
| 78 |
+
params->beta,
|
| 79 |
+
params->c, params->ldc, params->stride_c,
|
| 80 |
+
params->batch);
|
| 81 |
+
return OK;
|
| 82 |
+
}
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
template <typename T>
|
| 86 |
+
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
| 87 |
+
public:
|
| 88 |
+
TuningStatus Call(const ScaledGemmParams<T>* params) override {
|
| 89 |
+
at::cuda::blas::scaled_gemm(
|
| 90 |
+
params->transa,
|
| 91 |
+
params->transb,
|
| 92 |
+
params->m,
|
| 93 |
+
params->n,
|
| 94 |
+
params->k,
|
| 95 |
+
params->a,
|
| 96 |
+
params->a_scale_ptr,
|
| 97 |
+
params->lda,
|
| 98 |
+
params->a_dtype,
|
| 99 |
+
params->a_scale_dtype,
|
| 100 |
+
params->a_scaling_type,
|
| 101 |
+
params->b,
|
| 102 |
+
params->b_scale_ptr,
|
| 103 |
+
params->ldb,
|
| 104 |
+
params->b_dtype,
|
| 105 |
+
params->b_scale_dtype,
|
| 106 |
+
params->b_scaling_type,
|
| 107 |
+
params->bias_ptr,
|
| 108 |
+
params->bias_dtype,
|
| 109 |
+
params->c,
|
| 110 |
+
params->c_scale_ptr,
|
| 111 |
+
params->ldc,
|
| 112 |
+
params->c_dtype,
|
| 113 |
+
params->use_fast_accum,
|
| 114 |
+
std::nullopt /* alpha */);
|
| 115 |
+
return OK;
|
| 116 |
+
}
|
| 117 |
+
};
|
| 118 |
+
|
| 119 |
+
template <typename T>
|
| 120 |
+
inline bool IsZero(T v) {
|
| 121 |
+
return v == 0.0f;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
template <>
|
| 125 |
+
inline bool IsZero(BFloat16 v) {
|
| 126 |
+
return v.x == 0;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
template <>
|
| 130 |
+
inline bool IsZero(Half v) {
|
| 131 |
+
return float(v) == 0.0f;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
template <>
|
| 135 |
+
inline bool IsZero(c10::complex<double> v) {
|
| 136 |
+
return v == 0.0;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
template <>
|
| 140 |
+
inline bool IsZero(c10::complex<float> v) {
|
| 141 |
+
return v == 0.0f;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
template <typename T>
|
| 145 |
+
inline const char* TypeName(T v) {
|
| 146 |
+
return "unknown";
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
template <>
|
| 150 |
+
inline const char* TypeName(float v) {
|
| 151 |
+
if (at::globalContext().allowTF32CuBLAS()) {
|
| 152 |
+
return "tf32";
|
| 153 |
+
} else {
|
| 154 |
+
return "float";
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
template <>
|
| 159 |
+
inline const char* TypeName(double v) {
|
| 160 |
+
return "double";
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
template <>
|
| 164 |
+
inline const char* TypeName(BFloat16 v) {
|
| 165 |
+
return "BFloat16";
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
template <>
|
| 169 |
+
inline const char* TypeName(Half v) {
|
| 170 |
+
return "Half";
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
template <>
|
| 174 |
+
inline const char* TypeName(Float8_e4m3fn v) {
|
| 175 |
+
return "Float8_e4m3fn";
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
template <>
|
| 179 |
+
inline const char* TypeName(Float8_e5m2 v) {
|
| 180 |
+
return "Float8_e5m2";
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
template <>
|
| 184 |
+
inline const char* TypeName(Float8_e4m3fnuz v) {
|
| 185 |
+
return "Float8_e4m3fnuz";
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
template <>
|
| 189 |
+
inline const char* TypeName(Float8_e5m2fnuz v) {
|
| 190 |
+
return "Float8_e5m2fnuz";
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
template <>
|
| 194 |
+
inline const char* TypeName(Float8_e8m0fnu v) {
|
| 195 |
+
return "Float8_e8m0fnu";
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
template <>
|
| 199 |
+
inline const char* TypeName(c10::complex<double> v) {
|
| 200 |
+
return "c10::complex<double>";
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
template <>
|
| 204 |
+
inline const char* TypeName(c10::complex<float> v) {
|
| 205 |
+
return "c10::complex<float>";
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
| 209 |
+
class GemmTunableOp : public TunableOp<GemmParams<T>> {
|
| 210 |
+
public:
|
| 211 |
+
GemmTunableOp() {
|
| 212 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
|
| 213 |
+
|
| 214 |
+
#ifdef USE_ROCM
|
| 215 |
+
static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
|
| 216 |
+
if (!env_rocblas.has_value() || env_rocblas.value()) {
|
| 217 |
+
for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
|
| 218 |
+
this->RegisterOp(std::move(name), std::move(op));
|
| 219 |
+
}
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
| 223 |
+
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
|
| 224 |
+
// disallow tuning of hipblaslt with c10::complex
|
| 225 |
+
if constexpr (
|
| 226 |
+
!std::is_same_v<T, c10::complex<float>> &&
|
| 227 |
+
!std::is_same_v<T, c10::complex<double>>) {
|
| 228 |
+
for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
| 229 |
+
this->RegisterOp(std::move(name), std::move(op));
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
#endif
|
| 234 |
+
|
| 235 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
std::string Signature() override {
|
| 239 |
+
return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
|
| 240 |
+
}
|
| 241 |
+
};
|
| 242 |
+
|
| 243 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
| 244 |
+
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>> {
|
| 245 |
+
public:
|
| 246 |
+
GemmAndBiasTunableOp() {
|
| 247 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
|
| 248 |
+
|
| 249 |
+
#ifdef USE_ROCM
|
| 250 |
+
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
| 251 |
+
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
|
| 252 |
+
// disallow tuning of hipblaslt with c10::complex
|
| 253 |
+
if constexpr (
|
| 254 |
+
!std::is_same_v<T, c10::complex<float>> &&
|
| 255 |
+
!std::is_same_v<T, c10::complex<double>>) {
|
| 256 |
+
for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
|
| 257 |
+
this->RegisterOp(std::move(name), std::move(op));
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
#endif
|
| 262 |
+
|
| 263 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
std::string Signature() override {
|
| 267 |
+
return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
|
| 268 |
+
}
|
| 269 |
+
};
|
| 270 |
+
|
| 271 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
| 272 |
+
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
|
| 273 |
+
public:
|
| 274 |
+
GemmStridedBatchedTunableOp() {
|
| 275 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
|
| 276 |
+
|
| 277 |
+
#ifdef USE_ROCM
|
| 278 |
+
static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
|
| 279 |
+
if (!env_rocblas.has_value() || env_rocblas.value()) {
|
| 280 |
+
for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
|
| 281 |
+
this->RegisterOp(std::move(name), std::move(op));
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
| 286 |
+
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
|
| 287 |
+
// disallow tuning of hipblaslt with c10::complex
|
| 288 |
+
if constexpr (
|
| 289 |
+
!std::is_same_v<T, c10::complex<float>> &&
|
| 290 |
+
!std::is_same_v<T, c10::complex<double>>) {
|
| 291 |
+
for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
|
| 292 |
+
this->RegisterOp(std::move(name), std::move(op));
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
}
|
| 296 |
+
#endif
|
| 297 |
+
|
| 298 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
std::string Signature() override {
|
| 302 |
+
return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
|
| 303 |
+
}
|
| 304 |
+
};
|
| 305 |
+
|
| 306 |
+
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
|
| 307 |
+
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>> {
|
| 308 |
+
public:
|
| 309 |
+
ScaledGemmTunableOp() {
|
| 310 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
|
| 311 |
+
|
| 312 |
+
#ifdef USE_ROCM
|
| 313 |
+
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
|
| 314 |
+
this->RegisterOp(std::move(name), std::move(op));
|
| 315 |
+
}
|
| 316 |
+
#endif
|
| 317 |
+
|
| 318 |
+
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
std::string Signature() override {
|
| 322 |
+
return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c",
|
| 323 |
+
TypeName<AT>(AT{}),
|
| 324 |
+
TypeName<BT>(BT{}),
|
| 325 |
+
TypeName<CT>(CT{}),
|
| 326 |
+
BlasOpToString(ALayout), BlasOpToString(BLayout));
|
| 327 |
+
}
|
| 328 |
+
};
|
| 329 |
+
|
| 330 |
+
} // namespace at::cuda::tunable
|
| 331 |
+
|
| 332 |
+
#else
|
| 333 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 334 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Original TunableOp is from onnxruntime.
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 4 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 5 |
+
// Copyright (c) Microsoft Corporation.
|
| 6 |
+
// Licensed under the MIT license.
|
| 7 |
+
//
|
| 8 |
+
// Adapting TunableOp into PyTorch
|
| 9 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 10 |
+
//
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/StreamTimer.h>
#include <ATen/cuda/Sleep.h>
#include <c10/cuda/CUDACachingAllocator.h>

#ifndef _WIN32
#include <cxxabi.h>
#endif

#include <cstdlib>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
|
| 26 |
+
|
| 27 |
+
namespace at::cuda::tunable {
|
| 28 |
+
|
| 29 |
+
template <typename ParamsT>
|
| 30 |
+
class Callable {
|
| 31 |
+
public:
|
| 32 |
+
virtual ~Callable() = default;
|
| 33 |
+
virtual TuningStatus Call(const ParamsT* /*unused*/) {
|
| 34 |
+
return FAIL;
|
| 35 |
+
}
|
| 36 |
+
virtual TuningStatus IsSupported(const ParamsT* params) {
|
| 37 |
+
return Call(params);
|
| 38 |
+
}
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
namespace {
|
| 42 |
+
|
| 43 |
+
/** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */
|
| 44 |
+
|
| 45 |
+
// Running statistics over a stream of samples, updated one value at a time
// with Welford's online algorithm (no sample storage needed).
// The accumulators are public and read directly by callers.
class Stats {
 public:
  Stats() : _n(0UL), _mean(0.0), _M2(0.0), _sum(0.0), _min(0.0), _max(0.0) {}

  // Folds one sample into the count, sum, min/max, mean and M2 accumulators.
  void sample_value(const double x) {
    _sum = _sum + x;
    if (0UL == _n) {
      // First sample defines both extremes.
      _min = x;
      _max = x;
    }
    else {
      _min = _min < x ? _min : x;
      _max = _max > x ? _max : x;
    }
    _n = _n + 1UL;
    // Welford update: delta uses the old mean, the M2 term uses the new mean.
    const double delta = x - _mean;
    _mean = _mean + delta/_n;
    _M2 = _M2 + delta * (x - _mean);
  }

  // Sample variance (divides by n-1). With fewer than two samples the sample
  // variance is undefined; 0.0 is returned instead of dividing by zero
  // (n == 1) or by a wrapped-around unsigned value (n == 0).
  double variance() const {
    if (_n < 2UL) {
      return 0.0;
    }
    return _M2/(_n-1);
  }

  // Sample standard deviation: square root of variance().
  double stddev() const {
    return std::sqrt(variance());
  }

  unsigned long _n;   // number of samples seen
  double _mean;       // running mean
  double _M2;         // running sum of squared deviations from the mean
  double _sum;        // sum of all samples
  double _min;        // smallest sample (0.0 until the first sample)
  double _max;        // largest sample (0.0 until the first sample)
};
|
| 88 |
+
|
| 89 |
+
// Bounded history of strings: keeps at most `max_size` of the most recently
// pushed values and exposes them newest-first through rbegin()/rend().
class FixedSizeStack {
 private:
  std::deque<std::string> stack;
  const size_t max_size;

 public:
  FixedSizeStack(size_t size) : max_size(size) {}

  // Appends `value`, evicting the oldest entry when the stack is full.
  void push(const std::string& value) {
    if (max_size == 0) {
      // A zero-capacity stack keeps nothing. Without this guard the eviction
      // below would call pop_front() on an empty deque (undefined behaviour).
      return;
    }
    if (stack.size() >= max_size) {
      stack.pop_front(); // Remove the oldest entry
    }
    stack.push_back(value); // Add new entry
  }

  // Reverse iteration visits the newest entry first.
  auto rbegin() { return stack.rbegin(); }
  auto rend() { return stack.rend(); }
};
|
| 107 |
+
|
| 108 |
+
} // anonymous namespace
|
| 109 |
+
|
| 110 |
+
template <typename ParamsT>
|
| 111 |
+
class TunableOp {
|
| 112 |
+
public:
|
| 113 |
+
virtual ~TunableOp() = default;
|
| 114 |
+
|
| 115 |
+
// Executes the operation for `params` using the implementation selected by
// the tuning machinery and returns that implementation's status.
//
// When TunableOp is enabled, a stored result for (op signature, params
// signature) is used if present. Otherwise the op is tuned now if tuning is
// enabled, or recorded as untuned if recording is enabled. Whenever no result
// is available, the Default entry is used.
TuningStatus operator()(const ParamsT* params) {
  ResultEntry chosen = ResultEntry::Null();
  TuningContext* ctx = getTuningContext();

  if (!ctx->IsTunableOpEnabled()) {
    chosen = ResultEntry::Default();
  }
  else {
    auto& results_mgr = ctx->GetTuningResultsManager();
    auto op_sig = Signature();
    auto params_sig = params->Signature();
    auto blas_sig = params->BLASSignature();
    chosen = results_mgr.Lookup(op_sig, params_sig);
    // No previous tuning result was found: tune only if tuning is enabled.
    if (chosen == ResultEntry::Null()) {
      if (ctx->IsTuningEnabled()) {
        chosen = FindFastest(params);
        results_mgr.Add(op_sig, params_sig, chosen);
      }
      else if (ctx->IsRecordUntunedEnabled()) {
        // Tuning is off: record this GEMM into the untuned file instead.
        results_mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig, blas_sig);
      }
    }
  }

  if (chosen == ResultEntry::Null()) {
    TUNABLE_LOG2("no result, using default");
    chosen = ResultEntry::Default();
  }

  // The selected entry must name a registered implementation.
  auto found = ops_.find(chosen);
  TORCH_CHECK(found != ops_.end());
  return found->second->Call(params);
}
|
| 147 |
+
|
| 148 |
+
// Returns this op's signature string, building it on the first call via
// CreateSignature() and returning the cached value afterwards.
virtual std::string Signature() {
  // The signature is generated lazily rather than in the constructor.
  // According to the C++17 standard (https://wg21.link/n4659, section 15.7.4),
  // if the operand of typeid refers to the object under construction or
  // destruction, typeid yields the std::type_info object representing the
  // constructor or destructor's class, not the most-derived class.
  const auto build_signature = [this]() { signature_ = CreateSignature(); };
  c10::call_once(signature_init_once_, build_signature);
  return signature_;
}
|
| 157 |
+
|
| 158 |
+
protected:
|
| 159 |
+
void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
|
| 160 |
+
this->op_names_.emplace_back(name);
|
| 161 |
+
this->ops_.emplace(name, std::move(op));
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
private:
|
| 165 |
+
// Calls `op` num_iter times without timing it, flushing the instruction cache
// before each call when the tuning context asks for it.
//
// `param` is the pool of parameter copies to rotate through. `offset` is the
// caller's running rotation counter and is incremented once per call made
// here. Any call that does not return OK fails the TORCH_CHECK.
static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
  TuningContext* ctx = getTuningContext();
  const bool do_flush = ctx->IsICacheFlushEnabled();
  for (size_t i = 0; i < num_iter; i++) {
    if (do_flush) {
      at::cuda::flush_icache();
    }
    // NOTE(review): both i and offset advance each iteration, so consecutive
    // calls step through `param` two entries at a time - confirm this is intended.
    TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
  }
}
|
| 175 |
+
|
| 176 |
+
// Times num_iter back-to-back calls of `op` with a single timer and returns
// the timer's duration divided by num_iter.
//
// Two untimed calls are always made first. `param` is the pool of parameter
// copies to rotate through. `offset` is the caller's running rotation counter
// and is incremented once per call made here (including the two untimed
// ones). Any call that does not return OK fails the TORCH_CHECK.
static double ProfileSimple(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
  TuningContext* ctx = getTuningContext();
  const bool do_flush = ctx->IsICacheFlushEnabled();
  StreamTimerNoSync timer{};

  // Small Mandatory Warmup
  // Reduces outliers
  // NOTE(review): both i and offset advance each iteration, so consecutive
  // calls step through `param` two entries at a time - confirm this is intended.
  for (size_t i = 0; i < 2; i++) {
    TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
  }

  timer.Start();
  for (size_t i = 0; i < num_iter; i++) {
    if (do_flush) {
      at::cuda::flush_icache();
    }
    TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
  }
  timer.End();
  return timer.Duration() / num_iter;
}
|
| 197 |
+
|
| 198 |
+
// Times each of num_iter calls of `op` with its own timer and returns the
// per-call durations accumulated into a Stats object.
//
// Two untimed calls are always made first. The instruction cache flush (when
// enabled) happens after each timed call, outside the timed region. `param`
// is the pool of parameter copies to rotate through. `offset` is the caller's
// running rotation counter and is incremented once per call made here
// (including the two untimed ones). Any call that does not return OK fails
// the TORCH_CHECK.
static Stats ProfileStats(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
  TuningContext* ctx = getTuningContext();
  const bool do_flush = ctx->IsICacheFlushEnabled();
  std::vector<StreamTimerNoSync> timer(num_iter);

  // Small Mandatory Warmup
  // Reduces outliers
  // NOTE(review): both i and offset advance each iteration, so consecutive
  // calls step through `param` two entries at a time - confirm this is intended.
  for (size_t i = 0; i < 2; i++) {
    TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
  }

  for (size_t i = 0; i < num_iter; i++) {
    timer[i].Start();
    TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
    timer[i].End();
    if (do_flush) {
      at::cuda::flush_icache();
    }
  }

  // Durations are read only after all calls have been issued.
  Stats s;
  for (size_t i = 0; i < num_iter; i++) {
    s.sample_value(timer[i].Duration());
  }
  return s;
}
|
| 223 |
+
|
| 224 |
+
protected:
|
| 225 |
+
// Benchmarks every registered candidate on copies of `params` and returns a
// ResultEntry naming the candidate with the smallest mean duration ("Default"
// if no candidate beats the initial infinite duration).
//
// Per candidate, in registration order:
//   1. one trial call; a non-OK status skips the candidate;
//   2. a 3-iteration profile, skipped if mean > 1.5x the best so far;
//   3. a 10-iteration profile, skipped if mean > 1.15x the best so far;
//   4. optional numerics check against the Default implementation's output;
//   5. optional warmup, then the full tuning profile.
// All parameter copies made here are released before returning.
virtual ResultEntry FindFastest(const ParamsT* params) {
  TuningContext* ctx = getTuningContext();
  auto op_sig = Signature();
  auto params_sig = params->Signature();
  auto blas_sig = params->BLASSignature();
  TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
  auto min_duration_ms = std::numeric_limits<double>::infinity();
  std::string id_name = "Default";
  ParamsT* reference_params = nullptr;
  auto top_solns = FixedSizeStack(5);

  // numeric check option is controlled by non-static env var, so check it once per tuned operator
  bool do_numerics_check = ctx->IsNumericsCheckEnabled();

  // calculate a reference answer for numerical check
  if (do_numerics_check) {
    reference_params = params->DeepCopy(false);
    TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
  }

  // need copies of params to reuse
  // make as many copies as will fill the requested rotating buffer size, if requested
  // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int
  size_t rotating_size = ctx->GetRotatingBufferSize();
  bool use_buffer_rotation = (rotating_size > 0);
  size_t param_size = params->GetSize(use_buffer_rotation);
  // A zero-byte parameter set would make the division below an integer
  // division by zero; a single copy is all that is needed in that case.
  size_t param_count = (param_size > 0) ? (rotating_size / param_size) + 1 : 1;
  constexpr size_t MB = 1024ull*1024;
  if (use_buffer_rotation) {
    TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ",
        "Needed Size: ", param_size/MB, " MiB. ",
        "Needed number of param copies: ", param_count);
  }
  TORCH_CHECK(param_count > 0);

  std::vector<ParamsT*> reusable_params(param_count);
  for (size_t i = 0; i < param_count; i++) {
    reusable_params[i] = params->DeepCopy(use_buffer_rotation);
  }

  // for rotating buffer
  size_t offset = 0;

  for (size_t i = 0; i < op_names_.size(); i++) {
    auto* candidate = ops_[op_names_[i]].get(); // borrow pointer

    auto status = candidate->Call(reusable_params[0]);
    if (status != OK) {
      TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
      continue;
    }

    // collect a small profile
    int approx_num_iter = 3;
    auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
    double approx_duration = s._mean;
    // bail if too slow
    if (approx_duration > 1.5 * min_duration_ms) {
      TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
      continue;
    }

    // 2nd phase skip, more aggressive
    approx_num_iter = 10;
    s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
    approx_duration = s._mean;
    // bail if too slow
    if (approx_duration > 1.15 * min_duration_ms) {
      TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
      continue;
    }

    if (do_numerics_check) {
      // Run the candidate on a fresh copy and compare against the reference output.
      ParamsT* numerical_params = params->DeepCopy(false);
      auto status = candidate->Call(numerical_params);
      if (status != OK) {
        numerical_params->Delete();
        TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
        continue;
      }
      status = reference_params->NumericalCheck(numerical_params);
      numerical_params->Delete();
      if (status != OK) {
        TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
        continue;
      }
    }

    // for warmup does user set max duration, max iters, or both?
    // warmup is skipped by default, i.e. warmup_iter = 0
    // warmup will be set to the non-zero value of max_warmup_duration
    // or max_warmup_iter
    // if both are non-zero, we take the smaller of the two.
    // NOTE(review): the duration-based counts divide by approx_duration;
    // a measured mean of 0 is not handled here - confirm it cannot occur.
    double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
    int max_warmup_iter = ctx->GetMaxWarmupIterations();
    int warmup_iter = 0; // default
    if (max_warmup_duration > 0) {
      int duration_iters = max_warmup_duration / approx_duration;
      if (max_warmup_iter > 0) {
        warmup_iter = std::min(max_warmup_iter, duration_iters);
      }
      else {
        warmup_iter = duration_iters;
      }
    }
    else if (max_warmup_iter > 0) {
      warmup_iter = max_warmup_iter;
    }

    // for tuning does user set max duration, max iters, or both?
    double max_tuning_duration = ctx->GetMaxTuningDurationMs();
    int max_tuning_iter = ctx->GetMaxTuningIterations();
    int tuning_iter = 100; // default
    if (max_tuning_duration > 0) {
      int duration_iters = max_tuning_duration / approx_duration;
      if (max_tuning_iter > 0) {
        tuning_iter = std::min(max_tuning_iter, duration_iters);
      }
      else {
        tuning_iter = duration_iters;
      }
    }
    else if (max_tuning_iter > 0) {
      tuning_iter = max_tuning_iter;
    }
    // tuning must run at least 1 iteration
    tuning_iter = std::max(1, tuning_iter);

    // do the full warmup followed by tuning
    double warmup_ms = warmup_iter * approx_duration;
    double tuning_ms = tuning_iter * approx_duration;
    TUNABLE_LOG3("├──tuning using "
        "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
        "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
        "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
    TUNABLE_LOG3("├──offset at ", offset);
    WarmUp(candidate, reusable_params, warmup_iter, offset);
    s = ProfileStats(candidate, reusable_params, tuning_iter, offset);
    auto s_stddev = s.stddev();
    // Assume normal distribution.
    // Solution with smallest mean + 2*sigma will be a better solution?
    // if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) {
    if (s._mean < min_duration_ms) {
      TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
          " min ", s._min,
          " max ", s._max,
          " mean ", s._mean,
          " std ", s_stddev);
      min_duration_ms = s._mean;
      id_name = op_names_[i];
      // Only candidates that improved on the best so far are remembered here.
      std::string current_soln = std::to_string(s._mean) + " " + op_names_[i];
      top_solns.push(current_soln);
    }
    else {
      TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
          " min ", s._min,
          " max ", s._max,
          " mean ", s._mean,
          " std ", s_stddev);
    }
  }

  // Release every parameter copy made above.
  for (size_t i = 0; i < reusable_params.size(); i++) {
    reusable_params[i]->Delete();
  }
  if (reference_params) {
    reference_params->Delete();
  }

  TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
  TUNABLE_LOG2("└──top five solutions for ", op_sig, '(', params_sig, ") ");
  for (auto it = top_solns.rbegin(); it != top_solns.rend(); ++it) {
    TUNABLE_LOG2("   ", *it);
  }
  return ResultEntry(id_name, min_duration_ms, blas_sig);
}
|
| 401 |
+
|
| 402 |
+
private:
|
| 403 |
+
std::string CreateSignature() {
|
| 404 |
+
#ifndef _WIN32
|
| 405 |
+
const auto* name = typeid(*this).name();
|
| 406 |
+
// NOLINTNEXTLINE(*array*)
|
| 407 |
+
char buf[256];
|
| 408 |
+
size_t buf_len = 256;
|
| 409 |
+
abi::__cxa_demangle(name, buf, &buf_len, nullptr);
|
| 410 |
+
buf[255] = '\0';
|
| 411 |
+
return buf;
|
| 412 |
+
#else
|
| 413 |
+
return typeid(*this).name();
|
| 414 |
+
#endif
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
mutable c10::once_flag signature_init_once_;
|
| 418 |
+
std::string signature_;
|
| 419 |
+
|
| 420 |
+
std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
|
| 421 |
+
std::vector<std::string> op_names_;
|
| 422 |
+
};
|
| 423 |
+
|
| 424 |
+
// Abstract base for the parameter bundle of a tunable operation.
// Implementations supply the two signature strings declared below.
struct OpParams {
  virtual ~OpParams() = default;

  // Signature string for this parameter set.
  virtual std::string Signature() const = 0;

  // BLAS signature string for this parameter set.
  virtual std::string BLASSignature() const = 0;
};
|
| 429 |
+
|
| 430 |
+
} // namespace at::cuda::tunable
|
| 431 |
+
|
| 432 |
+
#else
|
| 433 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 434 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/functorch/Interpreter.h>
|
| 4 |
+
|
| 5 |
+
namespace at::functorch {
|
| 6 |
+
|
| 7 |
+
// These are the interpreters for our AD transforms
|
| 8 |
+
// (grad, vjp and jvp).
|
| 9 |
+
// See NOTE: [functorch interpreter stack] for more details.
|
| 10 |
+
|
| 11 |
+
struct TORCH_API GradInterpreterPtr {
|
| 12 |
+
explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
|
| 13 |
+
TransformType key() const { return base_->key(); }
|
| 14 |
+
int64_t level() const { return base_->level(); }
|
| 15 |
+
void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 16 |
+
void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 17 |
+
bool prevGradMode() const {
|
| 18 |
+
return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
|
| 19 |
+
}
|
| 20 |
+
Tensor lift(const Tensor& tensor) const;
|
| 21 |
+
private:
|
| 22 |
+
const Interpreter* base_;
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
struct TORCH_API JvpInterpreterPtr {
|
| 26 |
+
explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
|
| 27 |
+
TransformType key() const { return base_->key(); }
|
| 28 |
+
int64_t level() const { return base_->level(); }
|
| 29 |
+
void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 30 |
+
void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 31 |
+
bool prevFwdGradMode() const {
|
| 32 |
+
return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
|
| 33 |
+
}
|
| 34 |
+
Tensor lift(const Tensor& tensor) const;
|
| 35 |
+
private:
|
| 36 |
+
const Interpreter* base_;
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
} // namespace at::functorch
|
| 40 |
+
|
| 41 |
+
#else
|
| 42 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 43 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
#pragma once
|
| 8 |
+
|
| 9 |
+
#include <c10/util/TypeList.h>
|
| 10 |
+
|
| 11 |
+
#include <ATen/ATen.h>
|
| 12 |
+
#include <ATen/Operators.h>
|
| 13 |
+
|
| 14 |
+
#include <ATen/functorch/DynamicLayer.h>
|
| 15 |
+
#include <ATen/functorch/TensorWrapper.h>
|
| 16 |
+
#include <ATen/functorch/BatchingMetaprogramming.h>
|
| 17 |
+
#include <ATen/functorch/LegacyVmapTransforms.h>
|
| 18 |
+
#include <ATen/functorch/BatchedFallback.h>
|
| 19 |
+
#include <ATen/functorch/PlumbingHelper.h>
|
| 20 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 21 |
+
#include <ATen/VmapGeneratedPlumbing.h>
|
| 22 |
+
|
| 23 |
+
#include <utility>
|
| 24 |
+
|
| 25 |
+
// This file contains helper functions for batching rules.
|
| 26 |
+
|
| 27 |
+
namespace at::functorch {
|
| 28 |
+
|
| 29 |
+
TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
|
| 30 |
+
TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
|
| 31 |
+
|
| 32 |
+
TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);
|
| 33 |
+
|
| 34 |
+
Tensor moveBatchDimToFront(Tensor tensor, std::optional<int64_t> maybe_batch_dim);
|
| 35 |
+
int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
|
| 36 |
+
int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
|
| 37 |
+
std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val);
|
| 38 |
+
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
|
| 39 |
+
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);
|
| 40 |
+
|
| 41 |
+
void vmapIncompatibleInplaceError(const char* schema_name);
|
| 42 |
+
|
| 43 |
+
Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank);
|
| 44 |
+
|
| 45 |
+
void check_randomness(RandomnessType randomness);
|
| 46 |
+
void check_randomness(RandomnessType randomness, bool any_tensor_bdim);
|
| 47 |
+
|
| 48 |
+
// Returns `tensor` unchanged when `has_bdim` is true. Otherwise expands it to
// a shape made of `batch_size` followed by the tensor's own symbolic sizes.
inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
  if (has_bdim) {
    return tensor;
  }
  const auto sizes = tensor.sym_sizes();
  SymDimVector expanded_shape;
  // The expanded shape holds every existing size plus the prepended batch
  // size, so reserve room for one extra element.
  expanded_shape.reserve(sizes.size() + 1);
  expanded_shape.emplace_back(std::move(batch_size));
  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
  return tensor.expand_symint(expanded_shape);
}
|
| 59 |
+
|
| 60 |
+
#define VMAP_SUPPORT(op, batch_rule) \
|
| 61 |
+
m.impl(#op, op ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);
|
| 62 |
+
|
| 63 |
+
#define VMAP_SUPPORT2(op, overload, batch_rule) \
|
| 64 |
+
m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);
|
| 65 |
+
|
| 66 |
+
#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
|
| 67 |
+
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
|
| 68 |
+
|
| 69 |
+
// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
|
| 70 |
+
template <typename A, A a, typename C>
|
| 71 |
+
struct BasicUnaryBatchRuleHelper;
|
| 72 |
+
|
| 73 |
+
template <typename F, F Func, typename A, typename... T>
|
| 74 |
+
struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
|
| 75 |
+
static std::tuple<Tensor, std::optional<int64_t>> apply(
|
| 76 |
+
const Tensor& tensor,
|
| 77 |
+
std::optional<int64_t> batch_dim,
|
| 78 |
+
T... extra_args) {
|
| 79 |
+
return std::make_tuple(Func(tensor, std::forward<T>(extra_args)...), batch_dim);
|
| 80 |
+
}
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
// USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
|
| 84 |
+
// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
|
| 85 |
+
// It is important that this macro is not passed a function pointer!!
|
| 86 |
+
#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
|
| 87 |
+
BasicUnaryBatchRuleHelper<\
|
| 88 |
+
decltype(&fn),\
|
| 89 |
+
&fn,\
|
| 90 |
+
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
|
| 91 |
+
|
| 92 |
+
#define UNARY_POINTWISE(op) \
|
| 93 |
+
VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
|
| 94 |
+
|
| 95 |
+
template <typename A, A a, typename C>
|
| 96 |
+
struct VariadicBdimsBatchRuleHelper;
|
| 97 |
+
|
| 98 |
+
template <typename F, F Func, typename A, typename... T>
|
| 99 |
+
struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
|
| 100 |
+
static std::tuple<Tensor, std::optional<int64_t>> apply(
|
| 101 |
+
const Tensor& tensor,
|
| 102 |
+
std::optional<int64_t> batch_dim,
|
| 103 |
+
T... extra_args) {
|
| 104 |
+
auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
|
| 105 |
+
return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
|
| 106 |
+
}
|
| 107 |
+
};
|
| 108 |
+
|
| 109 |
+
// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
|
| 110 |
+
// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
|
| 111 |
+
// It is important that this macro is not passed a function pointer!!
|
| 112 |
+
#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
|
| 113 |
+
VariadicBdimsBatchRuleHelper<\
|
| 114 |
+
decltype(&fn),\
|
| 115 |
+
&fn,\
|
| 116 |
+
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
|
| 117 |
+
|
| 118 |
+
#define VARIADIC_BDIMS(op) \
|
| 119 |
+
VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));
|
| 120 |
+
|
| 121 |
+
#define VARIADIC_BDIMS2(op, overload) \
|
| 122 |
+
VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
|
| 123 |
+
|
| 124 |
+
template<class F, F Func>
|
| 125 |
+
void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
| 126 |
+
const auto& schema = op.schema();
|
| 127 |
+
const auto num_returns = schema.returns().size();
|
| 128 |
+
const auto num_arguments = schema.arguments().size();
|
| 129 |
+
|
| 130 |
+
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
|
| 131 |
+
auto maybe_layer = maybeCurrentDynamicLayer();
|
| 132 |
+
vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");
|
| 133 |
+
|
| 134 |
+
int64_t cur_level = maybe_layer->layerId();
|
| 135 |
+
|
| 136 |
+
auto orig_arguments = torch::jit::last(*stack, num_arguments);
|
| 137 |
+
if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
|
| 138 |
+
op.callBoxed(stack);
|
| 139 |
+
return;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
auto arguments = torch::jit::pop(*stack, num_arguments);
|
| 143 |
+
std::vector<std::pair<Tensor, std::optional<int64_t>>> tensor_inputs;
|
| 144 |
+
std::vector<int64_t> tensor_pos;
|
| 145 |
+
for (const auto idx : c10::irange(0, num_arguments)) {
|
| 146 |
+
const auto& ivalue = arguments[idx];
|
| 147 |
+
if (ivalue.isTensor()) {
|
| 148 |
+
auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
|
| 149 |
+
tensor_inputs.emplace_back(std::move(tensor_value), tensor_bdim);
|
| 150 |
+
tensor_pos.push_back(static_cast<int64_t>(idx));
|
| 151 |
+
}
|
| 152 |
+
}
|
| 153 |
+
Func(tensor_inputs);
|
| 154 |
+
|
| 155 |
+
size_t tensor_idx = 0;
|
| 156 |
+
TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
|
| 157 |
+
for (const auto arg_idx : c10::irange(0, num_arguments)) {
|
| 158 |
+
if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
|
| 159 |
+
torch::jit::push(stack, arguments[arg_idx]);
|
| 160 |
+
} else {
|
| 161 |
+
TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
|
| 162 |
+
torch::jit::push(stack, tensor_inputs[tensor_idx].first);
|
| 163 |
+
tensor_idx++;
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
op.callBoxed(stack);
|
| 168 |
+
const auto returns = torch::jit::pop(*stack, num_returns);
|
| 169 |
+
for (const auto& ret : returns) {
|
| 170 |
+
if (ret.isTensor()) {
|
| 171 |
+
torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
|
| 172 |
+
} else {
|
| 173 |
+
TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
inline void handle_pointwise_ops(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
|
| 179 |
+
int64_t out_logical_rank = 0;
|
| 180 |
+
for (auto& tensor_input : tensor_inputs) {
|
| 181 |
+
int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
|
| 182 |
+
out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
|
| 183 |
+
}
|
| 184 |
+
for (auto& tensor_input: tensor_inputs) {
|
| 185 |
+
tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
|
| 186 |
+
tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
|
| 187 |
+
}
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
#define POINTWISE_BOXED(op) \
|
| 191 |
+
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
|
| 192 |
+
|
| 193 |
+
#define POINTWISE_BOXED2(op, overload) \
|
| 194 |
+
m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
|
| 195 |
+
|
| 196 |
+
inline void handle_variadic_bdims(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
|
| 197 |
+
for (auto & tensor_input : tensor_inputs) {
|
| 198 |
+
tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
#define VARIADIC_BDIMS_BOXED(op) \
|
| 203 |
+
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
|
| 204 |
+
|
| 205 |
+
using UnpackedBatchedTensor = std::tuple<Tensor, std::optional<int64_t>>;
|
| 206 |
+
|
| 207 |
+
inline void find_and_unpack_tensors(
|
| 208 |
+
const torch::jit::Stack* stack,
|
| 209 |
+
int64_t num_args,
|
| 210 |
+
int64_t cur_level,
|
| 211 |
+
SmallVector<UnpackedBatchedTensor, 5>* tensors,
|
| 212 |
+
SmallVector<int64_t, 5>* tensors_pos,
|
| 213 |
+
int64_t* batch_size) {
|
| 214 |
+
|
| 215 |
+
int64_t computed_batch_size = -1;
|
| 216 |
+
int64_t args_begin = static_cast<int64_t>(stack->size()) - num_args;
|
| 217 |
+
|
| 218 |
+
for (const auto idx : c10::irange(0, num_args)) {
|
| 219 |
+
const auto& ivalue = (*stack)[args_begin + idx];
|
| 220 |
+
if (!ivalue.isTensor()) {
|
| 221 |
+
continue;
|
| 222 |
+
}
|
| 223 |
+
auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
|
| 224 |
+
const auto& [tensor_value, tensor_bdim] = unpacked;
|
| 225 |
+
if (tensor_bdim.has_value()) {
|
| 226 |
+
auto candidate_batch_size = tensor_value.size(*tensor_bdim);
|
| 227 |
+
if (computed_batch_size == -1) {
|
| 228 |
+
computed_batch_size = candidate_batch_size;
|
| 229 |
+
}
|
| 230 |
+
TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
tensors->push_back(std::move(unpacked));
|
| 234 |
+
tensors_pos->push_back(idx);
|
| 235 |
+
}
|
| 236 |
+
TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
|
| 237 |
+
*batch_size = computed_batch_size;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
inline void boxed_existing_bdim_all_batch_rule(
|
| 241 |
+
const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
| 242 |
+
const auto& schema = op.schema();
|
| 243 |
+
const auto num_returns = schema.returns().size();
|
| 244 |
+
const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
|
| 245 |
+
|
| 246 |
+
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
|
| 247 |
+
const auto maybe_layer = maybeCurrentDynamicLayer();
|
| 248 |
+
vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
|
| 249 |
+
|
| 250 |
+
const auto arguments = torch::jit::last(stack, num_arguments);
|
| 251 |
+
if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
|
| 252 |
+
op.callBoxed(stack);
|
| 253 |
+
return;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
int64_t args_begin = static_cast<int64_t>(stack->size()) - num_arguments;
|
| 257 |
+
SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
|
| 258 |
+
SmallVector<int64_t, 5> tensor_pos;
|
| 259 |
+
int64_t batch_size = 0;
|
| 260 |
+
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
| 261 |
+
int64_t cur_level = maybe_layer->layerId();
|
| 262 |
+
|
| 263 |
+
find_and_unpack_tensors(
|
| 264 |
+
stack, num_arguments, cur_level,
|
| 265 |
+
&tensor_inputs, &tensor_pos, &batch_size);
|
| 266 |
+
|
| 267 |
+
// for each tensor, ensure it has a bdim and reshape it.
|
| 268 |
+
for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
|
| 269 |
+
const auto& [value, bdim] = tensor_inputs[tensor_idx];
|
| 270 |
+
auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
|
| 271 |
+
(*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(bdim.value_or(0), 0, value_);
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
op.callBoxed(stack);
|
| 275 |
+
|
| 276 |
+
for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
|
| 277 |
+
const auto& ret = (*stack)[idx];
|
| 278 |
+
TORCH_INTERNAL_ASSERT(ret.isTensor(),
|
| 279 |
+
"This boxed batching rule does not currently support ops that return non-tensor values");
|
| 280 |
+
(*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// Use when all tensor arguments accept one (normal) batch dim.
|
| 285 |
+
// This batching rule expands the batch dim on all Tensors, reshapes it into
|
| 286 |
+
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
|
| 287 |
+
// This is not the most efficient thing; if there are alternatives, please try
|
| 288 |
+
// to use them. Use this only as a last resort.
|
| 289 |
+
#define EXISTING_BDIM_ALL_BOXED(op) \
|
| 290 |
+
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
|
| 291 |
+
|
| 292 |
+
template <int64_t feature_rank, int64_t contig_tensor_index=-1>
|
| 293 |
+
inline void boxed_all_tensors_have_optional_bdim(
|
| 294 |
+
const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
| 295 |
+
const auto& schema = op.schema();
|
| 296 |
+
const auto num_returns = schema.returns().size();
|
| 297 |
+
const auto num_arguments = schema.arguments().size();
|
| 298 |
+
|
| 299 |
+
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
|
| 300 |
+
auto maybe_layer = maybeCurrentDynamicLayer();
|
| 301 |
+
vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
|
| 302 |
+
int64_t cur_level = maybe_layer->layerId();
|
| 303 |
+
|
| 304 |
+
const auto arguments = torch::jit::last(stack, num_arguments);
|
| 305 |
+
if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
|
| 306 |
+
op.callBoxed(stack);
|
| 307 |
+
return;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
|
| 311 |
+
SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
|
| 312 |
+
SmallVector<int64_t, 5> tensor_pos;
|
| 313 |
+
int64_t batch_size = 0;
|
| 314 |
+
|
| 315 |
+
find_and_unpack_tensors(
|
| 316 |
+
stack, static_cast<int64_t>(num_arguments), cur_level,
|
| 317 |
+
&tensor_inputs, &tensor_pos, &batch_size);
|
| 318 |
+
|
| 319 |
+
std::optional<bool> is_no_batch_dim_case;
|
| 320 |
+
|
| 321 |
+
for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
|
| 322 |
+
const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
|
| 323 |
+
auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
|
| 324 |
+
const auto logical_rank = rankWithoutBatchDim(value, bdim);
|
| 325 |
+
|
| 326 |
+
if (!is_no_batch_dim_case.has_value()) {
|
| 327 |
+
is_no_batch_dim_case = (logical_rank == feature_rank);
|
| 328 |
+
}
|
| 329 |
+
auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
|
| 330 |
+
if (!bdim.has_value()) {
|
| 331 |
+
bdim = 0;
|
| 332 |
+
}
|
| 333 |
+
if (*is_no_batch_dim_case) {
|
| 334 |
+
TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
|
| 335 |
+
value_ = moveBatchDimToFront(value_, bdim);
|
| 336 |
+
if (tensor_idx == contig_tensor_index) {
|
| 337 |
+
value_ = value_.contiguous();
|
| 338 |
+
}
|
| 339 |
+
(*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
|
| 340 |
+
continue;
|
| 341 |
+
}
|
| 342 |
+
TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
|
| 343 |
+
value_ = reshape_dim_into(*bdim, 0, value_);
|
| 344 |
+
if (tensor_idx == contig_tensor_index) {
|
| 345 |
+
value_ = value_.contiguous();
|
| 346 |
+
}
|
| 347 |
+
(*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
op.callBoxed(stack);
|
| 351 |
+
|
| 352 |
+
for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
|
| 353 |
+
const auto& ret = (*stack)[idx];
|
| 354 |
+
TORCH_INTERNAL_ASSERT(ret.isTensor(),
|
| 355 |
+
"This boxed batching rule does not currently support ops that return non-tensor values");
|
| 356 |
+
if (*is_no_batch_dim_case) {
|
| 357 |
+
(*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
|
| 358 |
+
} else {
|
| 359 |
+
(*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
// Useful for many NN operators.
|
| 365 |
+
// The operator must satisfy the following:
|
| 366 |
+
// - All arguments must accept an optional batch dim.
|
| 367 |
+
// - All arguments must be the same rank
|
| 368 |
+
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
|
| 369 |
+
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());
|
| 370 |
+
|
| 371 |
+
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
|
| 372 |
+
m.impl(#op, \
|
| 373 |
+
torch::CppFunction::makeFromBoxedFunction<\
|
| 374 |
+
boxed_all_tensors_have_optional_bdim<\
|
| 375 |
+
feature_rank, \
|
| 376 |
+
contig_tensor_index>\
|
| 377 |
+
>());
|
| 378 |
+
|
| 379 |
+
template <typename A, A a, typename C>
|
| 380 |
+
struct ExistingBdimBatchRuleHelper;
|
| 381 |
+
|
| 382 |
+
template <typename F, F Func, typename A, typename... T>
|
| 383 |
+
struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
|
| 384 |
+
static std::tuple<Tensor, std::optional<int64_t>> apply(
|
| 385 |
+
const Tensor& self,
|
| 386 |
+
std::optional<int64_t> self_bdim,
|
| 387 |
+
T... extra_args) {
|
| 388 |
+
auto self_ = reshape_dim_into(*self_bdim, 0, self);
|
| 389 |
+
auto out = Func(self_, std::forward<T>(extra_args)...);
|
| 390 |
+
return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
|
| 391 |
+
}
|
| 392 |
+
};
|
| 393 |
+
|
| 394 |
+
// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
|
| 395 |
+
// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
|
| 396 |
+
// It is important that this macro is not passed a function pointer!!
|
| 397 |
+
#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
|
| 398 |
+
ExistingBdimBatchRuleHelper<\
|
| 399 |
+
decltype(&fn),\
|
| 400 |
+
&fn,\
|
| 401 |
+
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
#define EXISTING_BDIM(op) \
|
| 405 |
+
VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));
|
| 406 |
+
|
| 407 |
+
#define EXISTING_BDIM2(op, overload) \
|
| 408 |
+
VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
|
| 409 |
+
|
| 410 |
+
#define INVOKE(object,ptrToMember) ((object).*(ptrToMember))
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
template <typename F, F Method, typename... ExtraArgs>
|
| 414 |
+
Tensor& unary_inplace_batch_rule(Tensor& self, std::optional<int64_t> /*unused*/, ExtraArgs... extra_args) {
|
| 415 |
+
INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
|
| 416 |
+
return self;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
inline int64_t get_bdim_size4(
|
| 420 |
+
const Tensor& a_value, std::optional<int64_t> a_bdim,
|
| 421 |
+
const Tensor& b_value, std::optional<int64_t> b_bdim,
|
| 422 |
+
const Tensor& c_value, std::optional<int64_t> c_bdim,
|
| 423 |
+
const Tensor& d_value, std::optional<int64_t> d_bdim) {
|
| 424 |
+
if (a_bdim)
|
| 425 |
+
return a_value.size(*a_bdim);
|
| 426 |
+
if (b_bdim)
|
| 427 |
+
return b_value.size(*b_bdim);
|
| 428 |
+
if (c_bdim)
|
| 429 |
+
return c_value.size(*c_bdim);
|
| 430 |
+
if (d_bdim)
|
| 431 |
+
return d_value.size(*d_bdim);
|
| 432 |
+
TORCH_INTERNAL_ASSERT(false);
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
inline int64_t get_bdim_size3(
|
| 436 |
+
const Tensor& a_value, std::optional<int64_t> a_bdim,
|
| 437 |
+
const Tensor& b_value, std::optional<int64_t> b_bdim,
|
| 438 |
+
const Tensor& c_value, std::optional<int64_t> c_bdim) {
|
| 439 |
+
if (a_bdim)
|
| 440 |
+
return a_value.size(*a_bdim);
|
| 441 |
+
if (b_bdim)
|
| 442 |
+
return b_value.size(*b_bdim);
|
| 443 |
+
if (c_bdim)
|
| 444 |
+
return c_value.size(*c_bdim);
|
| 445 |
+
TORCH_INTERNAL_ASSERT(false);
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
inline int64_t get_bdim_size2(
|
| 449 |
+
const Tensor& a_value, std::optional<int64_t> a_bdim,
|
| 450 |
+
const Tensor& b_value, std::optional<int64_t> b_bdim) {
|
| 451 |
+
if (a_bdim)
|
| 452 |
+
return a_value.size(*a_bdim);
|
| 453 |
+
if (b_bdim)
|
| 454 |
+
return b_value.size(*b_bdim);
|
| 455 |
+
TORCH_INTERNAL_ASSERT(false);
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
// SymInt variant of get_bdim_size2: returns the symbolic size along the batch
// dim of `a_value` if it has a batch dim, else of `b_value`. Hits an internal
// assert when neither has one.
inline c10::SymInt get_bdim_size2_symint(
    const Tensor& a_value, std::optional<int64_t> a_bdim,
    const Tensor& b_value, std::optional<int64_t> b_bdim) {
  if (a_bdim.has_value()) {
    return a_value.sym_size(*a_bdim);
  }
  if (b_bdim.has_value()) {
    return b_value.sym_size(*b_bdim);
  }
  TORCH_INTERNAL_ASSERT(false);
}
|
| 467 |
+
|
| 468 |
+
// [start, start + 1, ..., stop - 1]
|
| 469 |
+
inline VmapDimVector range(int64_t start, int64_t stop) {
|
| 470 |
+
TORCH_INTERNAL_ASSERT(stop >= start);
|
| 471 |
+
VmapDimVector dims;
|
| 472 |
+
dims.reserve(stop - start);
|
| 473 |
+
for (int64_t i = start; i < stop; i++) {
|
| 474 |
+
dims.emplace_back(i);
|
| 475 |
+
}
|
| 476 |
+
return dims;
|
| 477 |
+
}
|
| 478 |
+
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
|
| 479 |
+
const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, const Tensor& other, std::optional<int64_t> other_batch_dim,
|
| 480 |
+
bool do_type_promotion=true);
|
| 481 |
+
|
| 482 |
+
} // namespace at::functorch
|
| 483 |
+
|
| 484 |
+
#else
|
| 485 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 486 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
#include <ATen/ATen.h>
|
| 10 |
+
#include <ATen/core/op_registration/op_registration.h>
|
| 11 |
+
#include <torch/library.h>
|
| 12 |
+
|
| 13 |
+
namespace at::functorch {
|
| 14 |
+
|
| 15 |
+
// This file contains code for the vmap fallback (also known as the
|
| 16 |
+
// BatchedTensor fallback or the Batched fallback). This code runs
|
| 17 |
+
// when an operation doesn't have a batching rule implemented.
|
| 18 |
+
|
| 19 |
+
// If an operator doesn't have a batching rule implemented then we fallback
|
| 20 |
+
// to this implementation. The fallback doesn't work on out= variants or
|
| 21 |
+
// view operations; that is, it works for out-of-place operations and
|
| 22 |
+
// in-place non-view operations.
|
| 23 |
+
//
|
| 24 |
+
// For out-of-place operations, the fallback effectively takes all of the
|
| 25 |
+
// BatchedTensors in `stack`, slices them, and runs `op` on all of the
|
| 26 |
+
// corresponding slices to produce slices of the outputs. The output slices
|
| 27 |
+
// then get `torch.stack`ed to create the
|
| 28 |
+
// final returns.
|
| 29 |
+
//
|
| 30 |
+
// The performance of the fallback is not very good because it introduces an
|
| 31 |
+
// extra copy from stacking the sliced outputs. Because of this, we prefer to
|
| 32 |
+
// write batching rules for operators whenever possible.
|
| 33 |
+
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 34 |
+
void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 35 |
+
|
| 36 |
+
void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 37 |
+
|
| 38 |
+
// The vmap fallback emits a warning by default, but it may be disabled if
|
| 39 |
+
// the user finds it to be too annoying.
|
| 40 |
+
TORCH_API bool isVmapFallbackWarningEnabled();
|
| 41 |
+
TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
|
| 42 |
+
|
| 43 |
+
// Used for testing. The vmap fallback is enabled by default. When it is disabled,
|
| 44 |
+
// it raises an error.
|
| 45 |
+
TORCH_API bool isVmapFallbackEnabled();
|
| 46 |
+
TORCH_API void setVmapFallbackEnabled(bool enabled);
|
| 47 |
+
|
| 48 |
+
template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
|
| 49 |
+
return buffer[0].to<A>();
|
| 50 |
+
}
|
| 51 |
+
template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
|
| 52 |
+
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
|
| 53 |
+
}
|
| 54 |
+
template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
|
| 55 |
+
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<B>());
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// slow_fallback is a way to call the vmap fallback inside some boxed kernel.
|
| 59 |
+
// There is probably some better way to metaprogram this.
|
| 60 |
+
template <typename Ret>
|
| 61 |
+
Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
| 62 |
+
std::vector<IValue> stack(args.begin(), args.end());
|
| 63 |
+
batchedTensorForLoopFallback(op, &stack);
|
| 64 |
+
return vector_to_result<Ret>(stack);
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template <typename A, typename B>
|
| 68 |
+
std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
| 69 |
+
std::vector<IValue> stack(args.begin(), args.end());
|
| 70 |
+
batchedTensorForLoopFallback(op, &stack);
|
| 71 |
+
return vector_to_result<A, B>(stack);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
template <typename A, typename B, typename C>
|
| 75 |
+
std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
| 76 |
+
std::vector<IValue> stack(args.begin(), args.end());
|
| 77 |
+
batchedTensorForLoopFallback(op, &stack);
|
| 78 |
+
return vector_to_result<A, B, C>(stack);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
} // namespace at::functorch
|
| 83 |
+
|
| 84 |
+
#else
|
| 85 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 86 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
|
| 10 |
+
#include <bitset>
|
| 11 |
+
|
| 12 |
+
#include <ATen/ArrayRef.h>
|
| 13 |
+
#include <ATen/SmallVector.h>
|
| 14 |
+
#include <ATen/Tensor.h>
|
| 15 |
+
|
| 16 |
+
namespace at::functorch {
|
| 17 |
+
|
| 18 |
+
using Tensor = at::Tensor;
|
| 19 |
+
|
| 20 |
+
// We assume this in a few other places in the codebase,
|
| 21 |
+
// but there isn't a centralized definition.
|
| 22 |
+
constexpr int64_t kVmapMaxTensorDims = 64;
|
| 23 |
+
|
| 24 |
+
// The valid vmap levels range from [0, 64). This effectively means that we
|
| 25 |
+
// support a maximum of 64 nested vmaps.
|
| 26 |
+
constexpr int64_t kVmapNumLevels = 64;
|
| 27 |
+
|
| 28 |
+
// Store this number of elements of BatchDims on the stack. Most people will
|
| 29 |
+
// probably use <= 5 nested vmaps, but adjust this number as necessary.
|
| 30 |
+
constexpr int64_t kBatchDimsStackSize = 5;
|
| 31 |
+
|
| 32 |
+
// A BatchedTensorImpl holds an underlying Tensor and a single batch dim
|
| 33 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 34 |
+
// BatchedTensorImpl.
|
| 35 |
+
//
|
| 36 |
+
// The batch dimensions are treated as being "private"; they are not user-visible.
|
| 37 |
+
// For example, in the following Tensor,
|
| 38 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
|
| 39 |
+
// dimension 0 is batch dimension.
|
| 40 |
+
//
|
| 41 |
+
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
|
| 42 |
+
// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
|
| 43 |
+
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
|
| 44 |
+
explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);
|
| 45 |
+
|
| 46 |
+
// Returns batch dimension of this tensor
|
| 47 |
+
int64_t bdim() const { return bdim_; }
|
| 48 |
+
|
| 49 |
+
// Returns the level of this tensor
|
| 50 |
+
int64_t level() const { return level_; }
|
| 51 |
+
|
| 52 |
+
// BatchedTensorImpl wraps a Tensor
|
| 53 |
+
const Tensor& value() const { return value_; }
|
| 54 |
+
|
| 55 |
+
// Given a public dimension index, return the dimension index in the underlying
|
| 56 |
+
// value() tensor.
|
| 57 |
+
// For example, if we have
|
| 58 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
|
| 59 |
+
// bt.actualDim(0) -> 1
|
| 60 |
+
// bt.actualDim(1) -> 2
|
| 61 |
+
// bt.actualDim(2) -> 3
|
| 62 |
+
// bt.actualDim(3) -> Error
|
| 63 |
+
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
|
| 64 |
+
|
| 65 |
+
IntArrayRef sizes_custom() const override;
|
| 66 |
+
SymIntArrayRef sym_sizes_custom() const override;
|
| 67 |
+
int64_t size_custom(int64_t d) const override;
|
| 68 |
+
c10::SymInt sym_size_custom(int64_t d) const override;
|
| 69 |
+
// We have to override this because we opted into CustomStrides
|
| 70 |
+
IntArrayRef strides_custom() const override;
|
| 71 |
+
SymIntArrayRef sym_strides_custom() const override;
|
| 72 |
+
// Override a bunch of methods inherited from TensorImpl to return error messages.
|
| 73 |
+
c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 74 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 75 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 76 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 77 |
+
const c10::VariableVersion& version_counter,
|
| 78 |
+
bool allow_tensor_metadata_change) const override;
|
| 79 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 80 |
+
c10::VariableVersion&& version_counter,
|
| 81 |
+
bool allow_tensor_metadata_change) const override;
|
| 82 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
| 83 |
+
#ifdef DEBUG
|
| 84 |
+
bool has_storage() const override;
|
| 85 |
+
#endif
|
| 86 |
+
|
| 87 |
+
void refreshTensorMetadata();
|
| 88 |
+
|
| 89 |
+
// Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
|
| 90 |
+
// accomplishes this is a hack where it is able to modify the levels of
|
| 91 |
+
// BatchedTensor to match the level of the current vmap transform.
|
| 92 |
+
void _unsafe_set_level(int64_t level) {
|
| 93 |
+
level_ = level;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// Used in batching rule for in-place view operations that can change
|
| 97 |
+
// the index of the bdim (think squeeze_, unsqueeze_)
|
| 98 |
+
void unsafe_set_bdim(int64_t bdim) {
|
| 99 |
+
// NB: you MUST call refreshTensorMetadata after doing this.
|
| 100 |
+
bdim_ = bdim;
|
| 101 |
+
}
|
| 102 |
+
private:
|
| 103 |
+
// see NOTE: [BatchedTensorImpl levels invariant]
|
| 104 |
+
void checkInvariants() const;
|
| 105 |
+
const char* tensorimpl_type_name() const override;
|
| 106 |
+
|
| 107 |
+
Tensor value_;
|
| 108 |
+
|
| 109 |
+
int64_t level_;
|
| 110 |
+
int64_t bdim_;
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 114 |
+
// BatchedTensorImpl.
|
| 115 |
+
inline bool isBatchedTensor(const Tensor& tensor) {
|
| 116 |
+
return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
|
| 117 |
+
tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
// It is unsafe to call this on a Tensor that is not backed by a
|
| 121 |
+
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
|
| 122 |
+
// Reinterprets the tensor's TensorImpl as a BatchedTensorImpl via an
// unchecked static_cast; no dispatch-key test is performed here.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  auto* impl = tensor.unsafeGetTensorImpl();
  return static_cast<BatchedTensorImpl*>(impl);
}
|
| 125 |
+
|
| 126 |
+
// Checked accessor: yields the BatchedTensorImpl behind `tensor` when
// isBatchedTensor(tensor) holds, and nullptr otherwise.
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  return isBatchedTensor(tensor) ? unsafeGetBatchedImpl(tensor) : nullptr;
}
|
| 132 |
+
|
| 133 |
+
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
|
| 134 |
+
// Builds a kVmapMaxTensorDims-wide bitset in which only bit `dim` is set.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
  // Start from an all-zero bitset and flip on the single requested bit.
  return std::bitset<kVmapMaxTensorDims>().set(dim);
}
|
| 139 |
+
|
| 140 |
+
// Creates a bitset for the given level
|
| 141 |
+
// Builds a kVmapNumLevels-wide bitset in which only bit `level` is set.
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
  // Start from an all-zero bitset and flip on the single requested bit.
  return std::bitset<kVmapNumLevels>().set(level);
}
|
| 146 |
+
|
| 147 |
+
// Use this to construct a BatchedTensor from a regular Tensor
|
| 148 |
+
TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);
|
| 149 |
+
|
| 150 |
+
// Adds a batch dim to `tensor`, returning a BatchedTensor
|
| 151 |
+
TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);
|
| 152 |
+
|
| 153 |
+
// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
|
| 154 |
+
// any wrapper Tensor subclasses). This is because there are methods on Tensor
|
| 155 |
+
// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
|
| 156 |
+
// TODO: should probably contain more (or all?) backend keys
|
| 157 |
+
constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
| 158 |
+
DispatchKey::Negative,
|
| 159 |
+
DispatchKey::Conjugate,
|
| 160 |
+
DispatchKey::XLA,
|
| 161 |
+
DispatchKey::XPU,
|
| 162 |
+
DispatchKey::HPU,
|
| 163 |
+
DispatchKey::CUDA,
|
| 164 |
+
DispatchKey::CPU,
|
| 165 |
+
DispatchKey::PrivateUse1,
|
| 166 |
+
DispatchKey::SparseCPU,
|
| 167 |
+
DispatchKey::SparseCUDA,
|
| 168 |
+
DispatchKey::SparseCsrCPU,
|
| 169 |
+
DispatchKey::SparseCsrCUDA,
|
| 170 |
+
});
|
| 171 |
+
|
| 172 |
+
// Returns the subset of `tensor`'s dispatch keys that should be copied onto a
// wrapper tensor.
//
// tensor:       the tensor whose key set is inspected.
// to_propagate: mask of keys eligible for propagation; defaults to
//               kKeysToPropagateToWrapper.
// Returns the intersection of the tensor's key set with `to_propagate`.
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  // Intersect with the caller-supplied mask. Previously this always used
  // kKeysToPropagateToWrapper, so a non-default `to_propagate` was ignored.
  return key_set & to_propagate;
}
|
| 176 |
+
|
| 177 |
+
} // namespace at::functorch
|
| 178 |
+
|
| 179 |
+
#else
|
| 180 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 181 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
#include <ATen/Tensor.h>
|
| 10 |
+
#include <ATen/VmapGeneratedPlumbing.h>
|
| 11 |
+
|
| 12 |
+
// This file contains template metaprogramming things that are used for our
|
| 13 |
+
// batching rules.
|
| 14 |
+
//
|
| 15 |
+
// See NOTE: [vmap plumbing] for more details on why this is necessary.
|
| 16 |
+
// The plumbing has a bunch of metaprogramming hacks for determining the signature
|
| 17 |
+
// of a batching rule from the signature of the operator, many of which use the
|
| 18 |
+
// helper functions in this file.
|
| 19 |
+
|
| 20 |
+
namespace at::functorch {
|
| 21 |
+
|
| 22 |
+
// Metaprogramming things
|
| 23 |
+
template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
|
| 24 |
+
template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
|
| 25 |
+
template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
|
| 26 |
+
template <typename T> class debug_t;
|
| 27 |
+
|
| 28 |
+
// tail operation
|
| 29 |
+
template<class TypeList>
|
| 30 |
+
struct tail final {
|
| 31 |
+
static_assert(c10::guts::false_t<TypeList>::value,
|
| 32 |
+
"In typelist::tail<T>, the T argument must be typelist<...>.");
|
| 33 |
+
};
|
| 34 |
+
template<class Head, class... Tail>
|
| 35 |
+
struct tail<typelist<Head, Tail...>> final {
|
| 36 |
+
using type = typelist<Tail...>;
|
| 37 |
+
};
|
| 38 |
+
template<class TypeList> using tail_t = typename tail<TypeList>::type;
|
| 39 |
+
|
| 40 |
+
template <class First, class Second, class Next, class Tail>
|
| 41 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
|
| 42 |
+
using type = Next;
|
| 43 |
+
};
|
| 44 |
+
template <class Next, class Tail>
|
| 45 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, std::optional<int64_t>, Next, Tail> {
|
| 46 |
+
using type = Tail;
|
| 47 |
+
};
|
| 48 |
+
template <class Next, class Tail>
|
| 49 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, std::optional<int64_t>, Next, Tail> {
|
| 50 |
+
using type = Tail;
|
| 51 |
+
};
|
| 52 |
+
template <class Next, class Tail>
|
| 53 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, std::optional<int64_t>, Next, Tail> {
|
| 54 |
+
using type = Tail;
|
| 55 |
+
};
|
| 56 |
+
template <class Next, class Tail>
|
| 57 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>, std::optional<int64_t>, Next, Tail> {
|
| 58 |
+
using type = Tail;
|
| 59 |
+
};
|
| 60 |
+
template <class Next, class Tail>
|
| 61 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
|
| 62 |
+
using type = Tail;
|
| 63 |
+
};
|
| 64 |
+
template <class Next, class Tail>
|
| 65 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
|
| 66 |
+
using type = Tail;
|
| 67 |
+
};
|
| 68 |
+
template <class Next, class Tail>
|
| 69 |
+
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::vector<Tensor>, std::optional<int64_t>, Next, Tail> {
|
| 70 |
+
using type = Tail;
|
| 71 |
+
};
|
| 72 |
+
template <class TypeList> struct RemoveBatchDimAfterTensor {
|
| 73 |
+
using first = head_t<TypeList>;
|
| 74 |
+
using next = tail_t<TypeList>;
|
| 75 |
+
using second = head_t<next>;
|
| 76 |
+
using tail = tail_t<next>;
|
| 77 |
+
|
| 78 |
+
using type = concat_t<
|
| 79 |
+
typelist<first>,
|
| 80 |
+
typename RemoveBatchDimAfterTensor<
|
| 81 |
+
typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
|
| 82 |
+
>::type
|
| 83 |
+
>;
|
| 84 |
+
};
|
| 85 |
+
template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
|
| 86 |
+
using type = typelist<Type>;
|
| 87 |
+
};
|
| 88 |
+
template <> struct RemoveBatchDimAfterTensor<typelist<>> {
|
| 89 |
+
using type = typelist<>;
|
| 90 |
+
};
|
| 91 |
+
template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;
|
| 92 |
+
|
| 93 |
+
// Type trait: maps std::tuple<X> (exactly one element) to X and leaves every
// other type, including tuples of other arities, unchanged.
template <typename Wrapped>
struct UnpackSingleItemTuple {
  using type = Wrapped;
};

// Specialization for the one-element tuple: expose the element type itself.
template <typename Only>
struct UnpackSingleItemTuple<std::tuple<Only>> {
  using type = Only;
};

// Convenience alias for UnpackSingleItemTuple<T>::type.
template <typename T>
using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;
|
| 100 |
+
|
| 101 |
+
// Turns a return type plus a std::tuple of parameter types into the
// corresponding function type. Only the std::tuple form is defined.
template <typename Result, typename ArgTuple>
struct BuildFunctionHelper;

// std::tuple<Params...> case: `type` is the function type Result(Params...).
template <typename Result, typename... Params>
struct BuildFunctionHelper<Result, std::tuple<Params...>> {
  using type = Result(Params...);
};
|
| 105 |
+
template <typename Return, typename TL>
|
| 106 |
+
struct BuildFunction {
|
| 107 |
+
using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
|
| 108 |
+
};
|
| 109 |
+
template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
template <typename batch_rule_t> struct ToOperatorType {
|
| 113 |
+
using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
|
| 114 |
+
using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;
|
| 115 |
+
|
| 116 |
+
using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
|
| 117 |
+
using operator_return_type =
|
| 118 |
+
unpack_single_item_tuple_t<
|
| 119 |
+
c10::guts::typelist::to_tuple_t<
|
| 120 |
+
remove_batch_dim_after_tensor_t<
|
| 121 |
+
c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;
|
| 122 |
+
|
| 123 |
+
using type = build_function_t<operator_return_type, operator_parameter_types>;
|
| 124 |
+
};
|
| 125 |
+
template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;
|
| 126 |
+
|
| 127 |
+
} // namespace at::functorch
|
| 128 |
+
|
| 129 |
+
#else
|
| 130 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 131 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
#include <ATen/functorch/Macros.h>
|
| 10 |
+
#include <c10/core/DispatchKey.h>
|
| 11 |
+
#include <ATen/core/function_schema.h>
|
| 12 |
+
#include <optional>
|
| 13 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 14 |
+
#include <ATen/functorch/Interpreter.h>
|
| 15 |
+
#include <ATen/functorch/VmapInterpreter.h>
|
| 16 |
+
#include <ATen/functorch/ADInterpreters.h>
|
| 17 |
+
#include <ATen/functorch/FunctionalizeInterpreter.h>
|
| 18 |
+
|
| 19 |
+
// Forward declared
|
| 20 |
+
namespace c10 { struct AutogradMetaInterface; }
|
| 21 |
+
|
| 22 |
+
namespace at::functorch {
|
| 23 |
+
|
| 24 |
+
// This file contains the implementation of functorch's interpreter stack.
|
| 25 |
+
// See NOTE: [functorch interpreter stack] first before reading on.
|
| 26 |
+
//
|
| 27 |
+
// NB: the functorch interpreter stack is also referred to as:
|
| 28 |
+
// - the "dynamic layer stack" -- an older name for "interpreter" was
|
| 29 |
+
// "dynamic layer".
|
| 30 |
+
// - the "functorch mode stack". You can think of each functorch transform as a
|
| 31 |
+
// "mode" (in the same sense as torch_dispatch mode or torch_function mode),
|
| 32 |
+
// and functorch being an implementation of a "mode stack" where the modes
|
| 33 |
+
// may be arbitrarily composed.
|
| 34 |
+
|
| 35 |
+
// DynamicLayer is basically the same thing as an Interpreter.
|
| 36 |
+
// It represents a functorch transform and it holds an Interpreter,
|
| 37 |
+
// which contains metadata related to the transform and instructions on
|
| 38 |
+
// how to perform the transform.
|
| 39 |
+
//
|
| 40 |
+
// TODO: we can excise DynamicLayer in favor of Interpreter,
|
| 41 |
+
// But I am going to leave it for now as a compatibility shim to avoid
|
| 42 |
+
// needing to refactor a lot of callsites...
|
| 43 |
+
struct TORCH_API DynamicLayer {
|
| 44 |
+
explicit DynamicLayer(
|
| 45 |
+
TransformType transform_type,
|
| 46 |
+
int64_t layerId,
|
| 47 |
+
std::optional<c10::SymInt> batchSize = std::nullopt,
|
| 48 |
+
std::optional<RandomnessType> randomness = std::nullopt,
|
| 49 |
+
std::optional<bool> prev_grad_mode = std::nullopt,
|
| 50 |
+
std::optional<bool> pre_fwd_grad_mode = std::nullopt,
|
| 51 |
+
std::optional<bool> functionalize_add_back_views = std::nullopt);
|
| 52 |
+
|
| 53 |
+
TransformType key() const;
|
| 54 |
+
int64_t layerId() const;
|
| 55 |
+
|
| 56 |
+
const Interpreter& interpreter() const { return interpreter_; }
|
| 57 |
+
Interpreter& interpreter() { return interpreter_; }
|
| 58 |
+
|
| 59 |
+
// Only valid for vmap
|
| 60 |
+
c10::SymInt batchSize() const;
|
| 61 |
+
RandomnessType randomness() const;
|
| 62 |
+
|
| 63 |
+
private:
|
| 64 |
+
Interpreter interpreter_;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
TORCH_API int64_t initAndPushDynamicLayer(
|
| 68 |
+
TransformType transform_type,
|
| 69 |
+
std::optional<c10::SymInt> batch_size = std::nullopt,
|
| 70 |
+
std::optional<RandomnessType> randomness = std::nullopt,
|
| 71 |
+
std::optional<bool> prev_grad_mode = std::nullopt,
|
| 72 |
+
std::optional<bool> prev_fwd_grad_mode = std::nullopt,
|
| 73 |
+
std::optional<bool> functionalize_add_back_views = std::nullopt);
|
| 74 |
+
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
|
| 75 |
+
TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
|
| 76 |
+
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
|
| 77 |
+
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
|
| 78 |
+
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
|
| 79 |
+
|
| 80 |
+
// NOTE: [Life handles and lexically scoped transforms]
|
| 81 |
+
// functorch transforms are lexically scoped.
|
| 82 |
+
// Given a level, we store a "life handle" that is a boolean that tells us if the
|
| 83 |
+
// transform with that level is active or not.
|
| 84 |
+
//
|
| 85 |
+
// functorch's TensorWrapper (for grad transforms) stores a life handle.
|
| 86 |
+
// If a TensorWrapper escapes from the scope of the transform, then somehow
|
| 87 |
+
// it must know it escaped; it can tell by querying the life handle.
|
| 88 |
+
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
|
| 89 |
+
|
| 90 |
+
// Returns if an operator is in-place. An operator is inplace if:
|
| 91 |
+
// 1. The first argument is a Tensor and it is being written to
|
| 92 |
+
// 2. The first argument is being returned
|
| 93 |
+
// 3. No other arguments are aliased
|
| 94 |
+
// Here is an example of an in-place operator:
|
| 95 |
+
// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
|
| 96 |
+
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
|
| 97 |
+
|
| 98 |
+
// Given the index of an unwrapped input and the schema, this returns the index of the output (if any) that should remain unwrapped
|
| 99 |
+
TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
|
| 100 |
+
|
| 101 |
+
TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
|
| 102 |
+
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
|
| 103 |
+
|
| 104 |
+
// Pretty printers
|
| 105 |
+
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
|
| 106 |
+
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
|
| 107 |
+
|
| 108 |
+
// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
|
| 109 |
+
// is disabled by default. The following two APIs are APIs for enabling
|
| 110 |
+
// it. These are not user-facing APIs. We can delete this in the future, but
|
| 111 |
+
// it is useful for debugging when something goes wrong with the
|
| 112 |
+
// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
|
| 113 |
+
// because it leads to loud errors if something is incorrect.
|
| 114 |
+
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
|
| 115 |
+
TORCH_API bool getSingleLevelAutogradFunctionAllowed();
|
| 116 |
+
|
| 117 |
+
// While a functorch grad transform is active, Tensor.requires_grad_() gets
|
| 118 |
+
// disabled. These two functions are the mechanism for controlling that.
|
| 119 |
+
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
|
| 120 |
+
TORCH_API bool getInplaceRequiresGradAllowed();
|
| 121 |
+
|
| 122 |
+
TORCH_API DynamicLayer popDynamicLayer();
|
| 123 |
+
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);
|
| 124 |
+
|
| 125 |
+
} // namespace at::functorch
|
| 126 |
+
|
| 127 |
+
#else
|
| 128 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 129 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/functorch/Interpreter.h>
|
| 4 |
+
|
| 5 |
+
namespace at::functorch {
|
| 6 |
+
|
| 7 |
+
// This is the interpreter that handles the functionalize() transform.
|
| 8 |
+
// See NOTE: [functorch interpreter stack] for more details.
|
| 9 |
+
|
| 10 |
+
struct FunctionalizeInterpreterPtr {
|
| 11 |
+
explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
|
| 12 |
+
TransformType key() const { return base_->key(); }
|
| 13 |
+
int64_t level() const { return base_->level(); }
|
| 14 |
+
void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 15 |
+
void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 16 |
+
bool functionalizeAddBackViews() const {
|
| 17 |
+
return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
|
| 18 |
+
}
|
| 19 |
+
private:
|
| 20 |
+
const Interpreter* base_;
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
} // namespace at::functorch
|
| 24 |
+
|
| 25 |
+
#else
|
| 26 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 27 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/functorch/Macros.h>
|
| 5 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 6 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 7 |
+
#include <c10/util/Exception.h>
|
| 8 |
+
#include <optional>
|
| 9 |
+
#include <bitset>
|
| 10 |
+
#include <utility>
|
| 11 |
+
#include <variant>
|
| 12 |
+
|
| 13 |
+
#include <nlohmann/json.hpp>
|
| 14 |
+
|
| 15 |
+
namespace at::functorch {
|
| 16 |
+
|
| 17 |
+
// NOTE: [functorch interpreter stack]
|
| 18 |
+
//
|
| 19 |
+
// functorch's dispatching system uses a stack of interpreters.
|
| 20 |
+
// Historically we've referred to this as the "DynamicLayerStack".
|
| 21 |
+
//
|
| 22 |
+
// An interpreter is something that reads in the code it is passed
|
| 23 |
+
// and then executes it. We have a different interpreter per-transform:
|
| 24 |
+
// the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
|
| 25 |
+
// and executing the batched version of it (the batching rule for aten::mv).
|
| 26 |
+
//
|
| 27 |
+
// Concretely, each interpreter is responsible for two things:
|
| 28 |
+
//
|
| 29 |
+
// 1) process(ophandle, stack)
|
| 30 |
+
// Given an operator handle and a stack of arguments, the interpreter is
|
| 31 |
+
// responsible for figuring out how to execute the operation under the semantics
|
| 32 |
+
// of the interpreter. E.g. for VmapInterpreter, this is figuring out how to call
|
| 33 |
+
// the batching rule.
|
| 34 |
+
//
|
| 35 |
+
// The batching rules are stored as kernels on the FuncTorchBatched key, so the way
|
| 36 |
+
// VmapInterpreter calls the batching rule is roughly: (A) exclude all
|
| 37 |
+
// dispatch keys aside from the Batched key, (B) redispatch so we get to the
|
| 38 |
+
// Batched key.
|
| 39 |
+
//
|
| 40 |
+
// 2) sendToNextInterpreter(ophandle, stack)
|
| 41 |
+
// The VmapInterpreter, when it sees aten::mv, will process it into a call to
|
| 42 |
+
// aten::mm. It then needs to send the call to aten::mm to the next interpreter
|
| 43 |
+
// in the interpreter stack.
|
| 44 |
+
//
|
| 45 |
+
// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
|
| 46 |
+
// and most Interpreters will implement it this way.
|
| 47 |
+
|
| 48 |
+
enum class RandomnessType {
|
| 49 |
+
Error, // always errors when calling a random function
|
| 50 |
+
Same, // randomness appears the same across batches
|
| 51 |
+
Different, // randomness appears different across batches
|
| 52 |
+
END
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
enum class TransformType {
|
| 56 |
+
Torch, // Unused
|
| 57 |
+
Vmap,
|
| 58 |
+
Grad, // reverse-mode AD, aka vjp
|
| 59 |
+
Jvp, // forward-mode AD
|
| 60 |
+
Functionalize,
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
std::ostream& operator<<(std::ostream& os, const TransformType& t);
|
| 64 |
+
|
| 65 |
+
// NOTE: [Interpreter "subclassing" design]
|
| 66 |
+
//
|
| 67 |
+
// How are various Interpreters for different transforms (vmap, grad, ...)
|
| 68 |
+
// implemented?
|
| 69 |
+
//
|
| 70 |
+
// Accessing interpreters is in the hot-path of functorch so we have a constraint
|
| 71 |
+
// that this code must be as fast as possible.
|
| 72 |
+
//
|
| 73 |
+
// As a result, we stay away from virtual methods and this causes our code
|
| 74 |
+
// to look a little funny.
|
| 75 |
+
//
|
| 76 |
+
// `Interpreter` is the struct for Interpreters. It holds ALL of the
|
| 77 |
+
// relevant information (what type of interpreter it is and the metadata).
|
| 78 |
+
// Metadata for each interpreter is represented as a Union (std::variant)
|
| 79 |
+
// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
|
| 80 |
+
//
|
| 81 |
+
// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
|
| 82 |
+
// if you want to access the metadata fields (like batchSize and randomness).
|
| 83 |
+
//
|
| 84 |
+
// Each type of interpreter (e.g. Vmap) has a convenience struct
|
| 85 |
+
// (e.g. VmapInterpreterPtr) associated with it.
|
| 86 |
+
//
|
| 87 |
+
// Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
|
| 88 |
+
// and then one can access methods on VmapInterpreterPtr like so:
|
| 89 |
+
// >>> VmapInterpreterPtr(&interpreter).batchSize()
|
| 90 |
+
//
|
| 91 |
+
// Finally, Interpreter::process switches on the type of the interpreter
|
| 92 |
+
// and calls one of {Transform}Interpreter::processImpl under the hood.
|
| 93 |
+
// Same for Interpreter::sendToNextInterpreter :)
|
| 94 |
+
|
| 95 |
+
struct VmapInterpreterMeta {
|
| 96 |
+
explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
|
| 97 |
+
batchSize_(std::move(batchSize)), randomness_(randomness) {}
|
| 98 |
+
|
| 99 |
+
c10::SymInt batchSize_;
|
| 100 |
+
RandomnessType randomness_;
|
| 101 |
+
|
| 102 |
+
VmapInterpreterMeta() = default;
|
| 103 |
+
VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
|
| 104 |
+
VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
|
| 105 |
+
VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
|
| 106 |
+
VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
|
| 107 |
+
~VmapInterpreterMeta() = default;
|
| 108 |
+
|
| 109 |
+
template <typename T>
|
| 110 |
+
friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
|
| 111 |
+
TORCH_CHECK(
|
| 112 |
+
!json_t.batchSize_.is_heap_allocated(),
|
| 113 |
+
"Serialization for heap-allocated SymInt is not implemented yet"
|
| 114 |
+
);
|
| 115 |
+
json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
|
| 116 |
+
json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
template <typename T>
|
| 120 |
+
friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
|
| 121 |
+
json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
|
| 122 |
+
json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
|
| 123 |
+
}
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
// Metadata attached to a grad transform. It carries a single flag,
// `prevGradMode_` (presumably the grad mode that was in effect before the
// transform was entered -- confirm against the code that constructs it).
struct GradInterpreterMeta {
  explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
  // Default construction is required so the value can be created first and
  // then populated by from_json.
  GradInterpreterMeta() = default;
  GradInterpreterMeta(const GradInterpreterMeta&) = default;
  GradInterpreterMeta(GradInterpreterMeta&&) = default;
  GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
  GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
  ~GradInterpreterMeta() = default;

  // Initialized in-class so a default-constructed instance holds `false`
  // rather than an indeterminate value that to_json would then read.
  bool prevGradMode_{false};

  // Stores the flag in `json_j` under the key "prevGradMode".
  template <typename T>
  friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
    json_j["prevGradMode"] = json_t.prevGradMode_;
  }

  // Reads the flag back from the "prevGradMode" entry of `json_j`.
  template <typename T>
  friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
    json_t.prevGradMode_ = json_j["prevGradMode"];
  }
};
|
| 146 |
+
|
| 147 |
+
// Metadata attached to a jvp transform. It carries a single flag,
// `prevFwdGradMode_` (presumably the forward-grad mode that was in effect
// before the transform was entered -- confirm against the constructing code).
struct JvpInterpreterMeta {
  explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
  // Default construction is required so the value can be created first and
  // then populated by from_json.
  JvpInterpreterMeta() = default;
  JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
  JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
  JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
  JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
  ~JvpInterpreterMeta() = default;

  // Initialized in-class so a default-constructed instance holds `false`
  // rather than an indeterminate value that to_json would then read.
  bool prevFwdGradMode_{false};

  // Stores the flag in `json_j` under the key "prevFwdGradMode".
  template <typename T>
  friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
    json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
  }

  // Reads the flag back from the "prevFwdGradMode" entry of `json_j`.
  template <typename T>
  friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
    json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
  }
};
|
| 167 |
+
|
| 168 |
+
// Metadata attached to a functionalize transform. It carries a single flag,
// `functionalizeAddBackViews_`, supplied at construction.
struct FunctionalizeInterpreterMeta {
  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
    functionalizeAddBackViews_(functionalizeAddBackViews) {}
  // Default construction is required so the value can be created first and
  // then populated by from_json.
  FunctionalizeInterpreterMeta() = default;
  FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
  FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
  FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
  FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
  ~FunctionalizeInterpreterMeta() = default;

  // Initialized in-class so a default-constructed instance holds `false`
  // rather than an indeterminate value that to_json would then read.
  bool functionalizeAddBackViews_{false};

  // Stores the flag in `json_j` under the key "functionalizeAddBackViews".
  template <typename T>
  friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
    json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
  }

  // Reads the flag back from the "functionalizeAddBackViews" entry of `json_j`.
  template <typename T>
  friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
    json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
  }
};
|
| 189 |
+
|
| 190 |
+
typedef std::variant<
|
| 191 |
+
int64_t,
|
| 192 |
+
GradInterpreterMeta,
|
| 193 |
+
JvpInterpreterMeta,
|
| 194 |
+
VmapInterpreterMeta,
|
| 195 |
+
FunctionalizeInterpreterMeta
|
| 196 |
+
> InterpreterMeta;
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
struct Interpreter {
|
| 200 |
+
// factory functions
|
| 201 |
+
static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
|
| 202 |
+
return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
|
| 203 |
+
}
|
| 204 |
+
static Interpreter Grad(int64_t level, bool prevGradMode) {
|
| 205 |
+
return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
|
| 206 |
+
}
|
| 207 |
+
static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
|
| 208 |
+
return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
|
| 209 |
+
}
|
| 210 |
+
static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
|
| 211 |
+
return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// methods
|
| 215 |
+
TransformType key() const { return type_; }
|
| 216 |
+
int64_t level() const { return level_; }
|
| 217 |
+
const InterpreterMeta& meta() const { return meta_; }
|
| 218 |
+
|
| 219 |
+
void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 220 |
+
void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 221 |
+
|
| 222 |
+
void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
|
| 223 |
+
TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
|
| 224 |
+
savedLocalDispatchKeySet_ = keyset;
|
| 225 |
+
}
|
| 226 |
+
void clearSavedLocalDispatchKeySet() {
|
| 227 |
+
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
|
| 228 |
+
savedLocalDispatchKeySet_ = std::nullopt;
|
| 229 |
+
}
|
| 230 |
+
c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
|
| 231 |
+
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
|
| 232 |
+
return *savedLocalDispatchKeySet_;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// An Interpreter is alive if we are currently inside the ongoing transform
|
| 236 |
+
// for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
|
| 237 |
+
// corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
|
| 238 |
+
bool is_alive() const {
|
| 239 |
+
return *is_alive_;
|
| 240 |
+
}
|
| 241 |
+
const std::shared_ptr<bool>& is_alive_ptr() const {
|
| 242 |
+
return is_alive_;
|
| 243 |
+
}
|
| 244 |
+
void set_is_alive(bool alive) {
|
| 245 |
+
*is_alive_ = alive;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
// Please don't use this
|
| 249 |
+
explicit Interpreter() = default;
|
| 250 |
+
|
| 251 |
+
template <typename T>
|
| 252 |
+
friend void to_json(T& json_j, const Interpreter& json_t) {
|
| 253 |
+
json_j["type"] = static_cast<int64_t>(json_t.type_);
|
| 254 |
+
json_j["level"] = json_t.level_;
|
| 255 |
+
if (json_t.savedLocalDispatchKeySet_) {
|
| 256 |
+
json_j["savedLocalDispatchKeySet"] = {
|
| 257 |
+
{"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
|
| 258 |
+
{"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
|
| 259 |
+
};
|
| 260 |
+
} else {
|
| 261 |
+
json_j["savedLocalDispatchKeySet"] = nlohmann::json();
|
| 262 |
+
}
|
| 263 |
+
json_j["is_alive"] = *json_t.is_alive_;
|
| 264 |
+
std::visit([&](auto&& arg) {
|
| 265 |
+
using V = std::decay_t<decltype(arg)>;
|
| 266 |
+
if constexpr (std::is_same_v<V, int64_t>) {
|
| 267 |
+
json_j["meta"] = {{"Torch", arg}};
|
| 268 |
+
} else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
|
| 269 |
+
json_j["meta"] = {{"Grad", arg}};
|
| 270 |
+
} else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
|
| 271 |
+
json_j["meta"] = {{"Jvp", arg}};
|
| 272 |
+
} else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
|
| 273 |
+
json_j["meta"] = {{"Vmap", arg}};
|
| 274 |
+
} else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
|
| 275 |
+
json_j["meta"] = {{"Functionalize", arg}};
|
| 276 |
+
} else {
|
| 277 |
+
static_assert(false && sizeof(V), "unknown variant case");
|
| 278 |
+
}
|
| 279 |
+
}, json_t.meta_);
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
template <typename T>
|
| 283 |
+
friend void from_json(const T& json_j, Interpreter& json_t) {
|
| 284 |
+
json_t.type_ = static_cast<TransformType>(json_j["type"]);
|
| 285 |
+
json_t.level_ = json_j["level"];
|
| 286 |
+
auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
|
| 287 |
+
if (savedLocalDispatchKeySet.is_null()) {
|
| 288 |
+
json_t.savedLocalDispatchKeySet_ = std::nullopt;
|
| 289 |
+
} else {
|
| 290 |
+
c10::impl::PODLocalDispatchKeySet pod;
|
| 291 |
+
pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
|
| 292 |
+
pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
|
| 293 |
+
json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
|
| 294 |
+
}
|
| 295 |
+
json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
|
| 296 |
+
auto meta = json_j["meta"];
|
| 297 |
+
if (meta.contains("Torch")) {
|
| 298 |
+
json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
|
| 299 |
+
} else if (meta.contains("Grad")) {
|
| 300 |
+
json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
|
| 301 |
+
} else if (meta.contains("Jvp")) {
|
| 302 |
+
json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
|
| 303 |
+
} else if (meta.contains("Vmap")) {
|
| 304 |
+
json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
|
| 305 |
+
} else if (meta.contains("Functionalize")) {
|
| 306 |
+
json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
|
| 307 |
+
} else {
|
| 308 |
+
TORCH_CHECK(false, "unknown interpreter metadata type");
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
std::string serialize() const {
|
| 313 |
+
return nlohmann::json(*this).dump();
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
static Interpreter deserialize(const std::string& serialized) {
|
| 317 |
+
return nlohmann::json::parse(serialized).get<Interpreter>();
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
private:
|
| 321 |
+
explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
|
| 322 |
+
type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
|
| 323 |
+
|
| 324 |
+
// fields
|
| 325 |
+
TransformType type_{};
|
| 326 |
+
int64_t level_{};
|
| 327 |
+
std::optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
|
| 328 |
+
std::shared_ptr<bool> is_alive_;
|
| 329 |
+
InterpreterMeta meta_;
|
| 330 |
+
};
|
| 331 |
+
|
| 332 |
+
// Applies the following for-loop:
|
| 333 |
+
// for i in range(begin, end):
|
| 334 |
+
// args[i] = func(args[i])
|
| 335 |
+
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
|
| 336 |
+
std::function<Tensor(const Tensor&)> func);
|
| 337 |
+
|
| 338 |
+
// Applies the following for-loop:
|
| 339 |
+
// for i in range(begin, end):
|
| 340 |
+
// if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset
|
| 341 |
+
// args[i] = func(args[i], i - begin, true)
|
| 342 |
+
// args[i] = func(args[i], i - begin)
|
| 343 |
+
void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
|
| 344 |
+
const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);
|
| 345 |
+
|
| 346 |
+
// Declaration only; the definition is not part of this header section.
// Takes the argument list plus a begin/end pair and returns a list of int64_t.
// NOTE(review): presumably the positions within [begin, end) whose inputs are
// not wrapped -- confirm against the definition.
std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);
|
| 347 |
+
|
| 348 |
+
// Declaration only; the definition is not part of this header section.
// Returns a DispatchKeySet chosen from the given transform type; going by the
// name, these are the keys to exclude when a dynamic layer is entered.
DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
|
| 349 |
+
|
| 350 |
+
// Declaration only; the definition is not part of this header section.
// NOTE(review): presumably configures the thread-local dispatch key state for
// transform `key`, including the keys in `include` -- confirm against the
// definition.
void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);
|
| 351 |
+
|
| 352 |
+
// Declaration only; the definition is not part of this header section.
// NOTE(review): presumably validates the contents of `stack` for operator
// `op`; what is checked and how failures are reported is not visible here.
void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 353 |
+
|
| 354 |
+
} // namespace at::functorch
|
| 355 |
+
|
| 356 |
+
#else
|
| 357 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 358 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|