BryanW commited on
Commit
3f10421
·
verified ·
1 Parent(s): fd876a9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h +218 -0
  2. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h +111 -0
  3. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction.h +346 -0
  4. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h +395 -0
  5. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h +32 -0
  6. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h +43 -0
  7. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h +46 -0
  8. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/boxing.h +415 -0
  9. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +790 -0
  10. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h +145 -0
  11. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/CppSignature.h +72 -0
  12. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h +285 -0
  13. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h +955 -0
  14. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h +22 -0
  15. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h +342 -0
  16. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h +35 -0
  17. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h +41 -0
  18. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/adaption.h +86 -0
  19. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/infer_schema.h +162 -0
  20. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h +186 -0
  21. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_registration.h +599 -0
  22. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/FlushDenormal.h +19 -0
  23. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/Utils.h +38 -0
  24. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vml.h +175 -0
  25. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/BLASConstants.h +16 -0
  26. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h +76 -0
  27. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h +156 -0
  28. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh +41 -0
  29. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh +129 -0
  30. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h +42 -0
  31. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h +16 -0
  32. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh +141 -0
  33. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +48 -0
  34. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh +121 -0
  35. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh +39 -0
  36. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h +705 -0
  37. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h +692 -0
  38. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h +282 -0
  39. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h +55 -0
  40. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/Tunable.h +270 -0
  41. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h +334 -0
  42. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h +434 -0
  43. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h +43 -0
  44. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h +486 -0
  45. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h +86 -0
  46. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h +181 -0
  47. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h +131 -0
  48. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h +129 -0
  49. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h +27 -0
  50. URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h +358 -0
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/boxing/OperatorKernel.h>
5
+ #include <c10/core/DispatchKeySet.h>
6
+ #include <c10/util/intrusive_ptr.h>
7
+
8
+ namespace c10 {
9
+
10
+ struct IValue;
11
+ using Stack = std::vector<IValue>;
12
+
13
+ class OperatorHandle;
14
+ class KernelFunction;
15
+
16
+ // This kernel implements the behavior of falling through to the next available
17
+ // registered dispatch key. The implementation of this function is FAST; it is
18
+ // no overhead to fallthrough to the next key. See cpp file for some more
19
+ // implementation notes; notably, this does NOT actually go through the
20
+ // boxing/unboxing codepath.
21
+ TORCH_API void fallthrough_kernel(
22
+ OperatorKernel* /*unused*/,
23
+ const OperatorHandle& /*unused*/,
24
+ DispatchKeySet /*unused*/,
25
+ Stack* /*unused*/);
26
+
27
+ // Note [Ambiguity in AutogradOther kernel]
28
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
29
+ // This error-reporting kernel is registered to the AutogradOther entry in the
30
+ // dispatch table when there is both a CompositeImplicitAutograd kernel and a
31
+ // backend kernel for ANY backend that maps to AutogradOther. To see why
32
+ // this is necessary in the AutogradOther case, it's helpful to first see
33
+ // why everything works out fine for a backend that has a reserved Autograd
34
+ // entry (see rule 2.2 in [Note] DispatchTable computation):
35
+ //
36
+ // CPU AutogradCPU
37
+ // reg? registers with...
38
+ // -------------------------------------------------
39
+ // y Autograd registration takes precedence
40
+ // over CompositeImplicitAutograd.
41
+ // This is good, because the CPU specific backend
42
+ // implementation is more specialized and typically better;
43
+ // if we used the composite, we would bypass it.
44
+ // (NB: the Autograd key is guaranteed to exist because
45
+ // the autograd codegen requires it!)
46
+ //
47
+ // n CompositeImplicitAutograd takes precedence.
48
+ // This is also good, because the Autograd
49
+ // registration (if it exists) would try to redispatch
50
+ // to the (non-existent) CPU implementation; by
51
+ // using the composite, we ensure the operator
52
+ // actually works.
53
+ //
54
+ // As you can see, when we have a specific Autograd key (AutogradCPU), we can
55
+ // decide whether or not to use the CompositeImplicitAutograd kernel or the
56
+ // Autograd kernel based on whether or not the backend kernel exists.
57
+ //
58
+ // However, for AutogradOther (which is the catchall autograd kernel for
59
+ // everything that doesn't have a specific Autograd key), we can't do this
60
+ // trick because there isn't any unique backend to peek at to disambiguate;
61
+ // if there are some backends that have implementations they prefer Autograd,
62
+ // but unimplemented backends would prefer CompositeImplicitAutograd. Rather
63
+ // than arbitrarily pick one or the other, we just register a kernel that raises
64
+ // an error and let the user decide how to proceed.
65
+ TORCH_API void ambiguous_autogradother_kernel(
66
+ OperatorKernel* /*unused*/,
67
+ const OperatorHandle& /*op*/,
68
+ DispatchKeySet /*unused*/,
69
+ Stack* /*unused*/);
70
+
71
+ // Note [named_not_supported_kernel]
72
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
73
+ // This kernel implements reporting an error message saying that named tensor is
74
+ // not supported. This kernel doesn't rely on the Stack, and so it is special
75
+ // cased in the dispatcher to be triggered before we attempt boxing (so we can
76
+ // give a good error message in cases when boxing is not supported). When
77
+ // boxing is universally supported this can be removed.
78
+ [[noreturn]] TORCH_API void named_not_supported_kernel(
79
+ OperatorKernel* /*unused*/,
80
+ const OperatorHandle& /*op*/,
81
+ DispatchKeySet /*unused*/,
82
+ Stack* /*unused*/);
83
+
84
+ /**
85
+ * BoxedKernel is similar to a std::function storing a boxed kernel.
86
+ */
87
+ class TORCH_API BoxedKernel final {
88
+ public:
89
+ // This is how boxed kernels are actually stored
90
+ //
91
+ // Note [Plumbing Keys Through The Dispatcher]
92
+ // Benchmarks have shown that it is expensive for the dispatcher to read from
93
+ // thread-local storage (TLS) upon every dispatch call into order to compute
94
+ // which kernel to dispatch to.
95
+ //
96
+ // To mitigate this, we've updated the calling convention inside the
97
+ // dispatcher to expect every kernel that it stores to have a first argument
98
+ // of type DispatchKeySet.
99
+ //
100
+ // What are the invariants of the DispatchKeySet when it gets passed to a
101
+ // kernel?
102
+ // - All keys to the left of the current dispatch key have been masked out.
103
+ // (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the
104
+ // highest bit to be DispatchKey::Tracer)
105
+ // - All other keys that dispatcher normally would have computed through TLS +
106
+ // global state + op arguments
107
+ // are still in the set.
108
+ //
109
+ // Kernels can then opt into using this keyset to save the dispatcher from
110
+ // doing repeated work during redispatches: recalculating the highest-priority
111
+ // dispatch key, which involves reading from TLS. Instead, the kernels that
112
+ // opt in will calculate an updated DispatchKeySet directly from the old one,
113
+ // and pass the updated set directly into the dispatcher upon redispatching.
114
+ //
115
+ // This is an opt-in mechanism: Kernels can automatically opt in by setting
116
+ // the first argument in their signature to be of type DispatchKeySet. See the
117
+ // kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for
118
+ // examples.
119
+ //
120
+ // The mechanism for optionally passing that DispatchKeySet into the kernel
121
+ // lives in make_boxed_from_unboxed_functor.h. See Note [Plumbing Keys Through
122
+ // The Dispatcher 2] for details.
123
+ using InternalBoxedKernelFunction =
124
+ void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
125
+ // This is the public API for how boxed kernels are defined
126
+ using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
127
+ using BoxedKernelFunction_withDispatchKeys =
128
+ void(const OperatorHandle&, DispatchKeySet, Stack*);
129
+
130
+ BoxedKernel();
131
+
132
+ // Fast path for dispatch to allow not touching the boxed kernel in
133
+ // the common case where unboxed is available.
134
+ bool isValid() const;
135
+ bool isFallthrough() const;
136
+
137
+ /**
138
+ * Call the function with boxed arguments.
139
+ */
140
+ void callBoxed(
141
+ const OperatorHandle& opHandle,
142
+ DispatchKeySet dispatchKeySet,
143
+ Stack* stack) const;
144
+
145
+ /**
146
+ * Create a KernelFunction from a boxed function.
147
+ *
148
+ * Example:
149
+ *
150
+ * > void boxed_func(OperatorKernel*, Stack* stack) {...}
151
+ * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>();
152
+ */
153
+ template <BoxedKernelFunction* func>
154
+ static BoxedKernel makeFromFunction();
155
+
156
+ /**
157
+ * TODO: This will only be useful if we write a backend fallback that plumbs
158
+ * dispatch keys (currently there are none) See Note [Plumbing Keys Through
159
+ * The Dispatcher] for details.
160
+ */
161
+ template <BoxedKernelFunction_withDispatchKeys* func>
162
+ static BoxedKernel makeFromFunction();
163
+
164
+ /**
165
+ * Create a KernelFunction from a boxed functor.
166
+ *
167
+ * Example:
168
+ *
169
+ * > class MyFunctor final : public c10::OperatorKernel {
170
+ * > public:
171
+ * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
172
+ * > };
173
+ * > BoxedKernel func =
174
+ * BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
175
+ */
176
+ template <class KernelFunctor>
177
+ static BoxedKernel makeFromFunctor(
178
+ std::unique_ptr<KernelFunctor> kernelFunctor);
179
+
180
+ static BoxedKernel makeFallthrough();
181
+ static BoxedKernel makeAmbiguousAutogradOther();
182
+ static BoxedKernel makeNamedNotSupported();
183
+
184
+ private:
185
+ friend class KernelFunction;
186
+
187
+ template <BoxedKernelFunction* func>
188
+ static void make_boxed_function(
189
+ OperatorKernel* /*unused*/,
190
+ const OperatorHandle& opHandle,
191
+ DispatchKeySet /*unused*/,
192
+ Stack* stack);
193
+
194
+ template <BoxedKernelFunction_withDispatchKeys* func>
195
+ static void make_boxed_function(
196
+ OperatorKernel* /*unused*/,
197
+ const OperatorHandle& opHandle,
198
+ DispatchKeySet /*ks*/,
199
+ Stack* stack);
200
+
201
+ explicit BoxedKernel(
202
+ std::unique_ptr<OperatorKernel> functor,
203
+ InternalBoxedKernelFunction* boxed_kernel_func);
204
+
205
+ OperatorKernel* getFunctor() const;
206
+ InternalBoxedKernelFunction* getFnPtr() const;
207
+
208
+ c10::intrusive_ptr<OperatorKernel> functor_;
209
+ InternalBoxedKernelFunction* boxed_kernel_func_;
210
+ };
211
+
212
+ } // namespace c10
213
+
214
+ #include <ATen/core/boxing/BoxedKernel_impl.h>
215
+
216
+ #else
217
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
218
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ namespace c10 {
5
+
6
+ inline BoxedKernel::BoxedKernel() : boxed_kernel_func_(nullptr) {}
7
+
8
+ inline BoxedKernel::BoxedKernel(
9
+ std::unique_ptr<OperatorKernel> functor,
10
+ InternalBoxedKernelFunction* boxed_kernel_func)
11
+ : functor_(std::move(functor)), boxed_kernel_func_(boxed_kernel_func) {}
12
+
13
+ template <BoxedKernel::BoxedKernelFunction* func>
14
+ inline void BoxedKernel::make_boxed_function(
15
+ OperatorKernel* /*unused*/,
16
+ const OperatorHandle& opHandle,
17
+ DispatchKeySet /*unused*/,
18
+ Stack* stack) {
19
+ // Note that we're dropping the DispatchKeySet argument.
20
+ // See Note [Plumbing Keys Through The Dispatcher 2] for details.
21
+ func(opHandle, stack);
22
+ }
23
+
24
+ template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
25
+ inline void BoxedKernel::make_boxed_function(
26
+ OperatorKernel* /*unused*/,
27
+ const OperatorHandle& opHandle,
28
+ DispatchKeySet ks,
29
+ Stack* stack) {
30
+ // See Note [Plumbing Keys Through The Dispatcher 2] for details.
31
+ func(opHandle, ks, stack);
32
+ }
33
+
34
+ inline bool BoxedKernel::isValid() const {
35
+ return boxed_kernel_func_ != nullptr;
36
+ }
37
+
38
+ inline bool BoxedKernel::isFallthrough() const {
39
+ return boxed_kernel_func_ == &fallthrough_kernel;
40
+ }
41
+
42
+ inline void BoxedKernel::callBoxed(
43
+ const OperatorHandle& opHandle,
44
+ DispatchKeySet dispatchKeySet,
45
+ Stack* stack) const {
46
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
47
+ boxed_kernel_func_ != nullptr,
48
+ "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel.");
49
+ (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
50
+ }
51
+
52
+ template <BoxedKernel::BoxedKernelFunction* func>
53
+ inline BoxedKernel BoxedKernel::makeFromFunction() {
54
+ return BoxedKernel(
55
+ nullptr, // no functor_ object
56
+ &make_boxed_function<func>);
57
+ }
58
+
59
+ template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
60
+ inline BoxedKernel BoxedKernel::makeFromFunction() {
61
+ return BoxedKernel(
62
+ nullptr, // no functor_ object
63
+ &make_boxed_function<func>);
64
+ }
65
+
66
+ inline BoxedKernel BoxedKernel::makeFallthrough() {
67
+ return BoxedKernel(
68
+ nullptr, // no functor_ object
69
+ &fallthrough_kernel);
70
+ }
71
+
72
+ inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
73
+ return BoxedKernel(
74
+ nullptr, // no functor_ object
75
+ &ambiguous_autogradother_kernel);
76
+ }
77
+
78
+ inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
79
+ return BoxedKernel(
80
+ nullptr, // no functor_ object
81
+ &named_not_supported_kernel);
82
+ }
83
+
84
+ template <class KernelFunctor>
85
+ inline BoxedKernel BoxedKernel::makeFromFunctor(
86
+ std::unique_ptr<KernelFunctor> kernelFunctor) {
87
+ static_assert(
88
+ std::is_base_of_v<OperatorKernel, KernelFunctor>,
89
+ "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
90
+ return BoxedKernel(
91
+ std::move(kernelFunctor),
92
+ [](OperatorKernel* kernel,
93
+ const OperatorHandle& op,
94
+ DispatchKeySet ks,
95
+ Stack* stack) {
96
+ (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
97
+ });
98
+ }
99
+
100
+ inline OperatorKernel* BoxedKernel::getFunctor() const {
101
+ return functor_.get();
102
+ }
103
+ inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
104
+ return boxed_kernel_func_;
105
+ }
106
+
107
+ } // namespace c10
108
+
109
+ #else
110
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
111
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction.h ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/ATen_fwd.h>
5
+ #include <ATen/core/boxing/BoxedKernel.h>
6
+ #include <ATen/core/stack.h>
7
+ #include <c10/core/DispatchKeySet.h>
8
+ #include <c10/util/TypeList.h>
9
+ #include <c10/util/intrusive_ptr.h>
10
+ #include <atomic>
11
+ #include <memory>
12
+ #include <type_traits>
13
+
14
+ namespace c10 {
15
+
16
+ using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
17
+ // to the c10 namespace.
18
+
19
+ class OperatorHandle;
20
+ struct OperatorKernel;
21
+ class KernelFunction;
22
+
23
+ class KernelToken;
24
+ class SafeKernelFunction;
25
+
26
+ template <typename T>
27
+ using has_symint = std::disjunction<
28
+ std::is_same<c10::SymInt, T>,
29
+ std::is_same<c10::SymIntArrayRef, T>,
30
+ std::is_same<at::OptionalSymIntArrayRef, T>,
31
+ std::is_same<std::optional<c10::SymInt>, T>>;
32
+
33
+ template <typename T>
34
+ struct remove_symint {
35
+ using type = T;
36
+ };
37
+
38
+ template <>
39
+ struct remove_symint<c10::SymInt> {
40
+ using type = int64_t;
41
+ };
42
+
43
+ template <>
44
+ struct remove_symint<at::OptionalSymIntArrayRef> {
45
+ using type = OptionalIntArrayRef;
46
+ };
47
+
48
+ template <>
49
+ struct remove_symint<c10::SymIntArrayRef> {
50
+ using type = c10::IntArrayRef;
51
+ };
52
+
53
+ template <>
54
+ struct remove_symint<std::optional<c10::SymInt>> {
55
+ using type = std::optional<int64_t>;
56
+ };
57
+
58
+ template <bool symint, typename T>
59
+ struct maybe_keep_symint final {};
60
+
61
+ template <typename T>
62
+ struct maybe_keep_symint<true, T> {
63
+ using type = T;
64
+ };
65
+
66
+ template <typename T>
67
+ struct maybe_keep_symint<false, T> {
68
+ using type = typename remove_symint<T>::type;
69
+ };
70
+
71
+ template <typename T>
72
+ using fn_has_symint = typename guts::typelist::true_for_any_type<
73
+ has_symint,
74
+ typename guts::infer_function_traits<T>::type::parameter_types>;
75
+
76
+ template <typename T>
77
+ struct fn_remove_symint;
78
+
79
+ template <typename Ret, typename... Args>
80
+ struct fn_remove_symint<Ret(Args...)> {
81
+ using type = Ret(typename remove_symint<Args>::type...);
82
+ };
83
+
84
+ /**
85
+ * KernelFunction is similar to std::function but stores a kernel function.
86
+ * You can create a KernelFunction from a boxed or unboxed
87
+ * function/functor/lambda and call it in a boxed or unboxed way. If the way it
88
+ * was created doesn't match the way it was called, it will do boxing or
89
+ * unboxing as necessary.
90
+ */
91
+ class TORCH_API KernelFunction final {
92
+ public:
93
+ using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
94
+ using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
95
+ using BoxedKernelFunction_withDispatchKeys =
96
+ BoxedKernel::BoxedKernelFunction_withDispatchKeys;
97
+
98
+ KernelFunction();
99
+ ~KernelFunction();
100
+
101
+ KernelFunction(const KernelFunction& other);
102
+ KernelFunction& operator=(const KernelFunction& other);
103
+
104
+ KernelFunction(KernelFunction&&) noexcept = default;
105
+
106
+ // Fast path for dispatch to allow not touching the boxed kernel in
107
+ // the common case where unboxed is available.
108
+ bool isValidUnboxed() const;
109
+ bool isValidSymUnboxed() const;
110
+ bool isValid() const;
111
+ bool isFallthrough() const;
112
+
113
+ /**
114
+ * Call the function in a boxed way.
115
+ * If the kernel function was created with an unboxed function,
116
+ * this will call an unboxing wrapper which then calls into that
117
+ * unboxed function.
118
+ *
119
+ * Example:
120
+ *
121
+ * > void boxed_func(OperatorKernel*, Stack* stack) {...}
122
+ * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
123
+ * > Tensor result = func.callBoxed(stack);
124
+ *
125
+ * Or, with an unboxed implementation:
126
+ *
127
+ * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
128
+ * > [] (Tensor a, bool b) -> Tensor {...});
129
+ * > Tensor result = func.callBoxed(stack);
130
+ */
131
+ void callBoxed(
132
+ const OperatorHandle& opHandle,
133
+ DispatchKeySet dispatchKeySet,
134
+ Stack* stack) const;
135
+
136
+ /**
137
+ * Call the function in an unboxed way.
138
+ * If the kernel function was created with a boxed function,
139
+ * this will box all inputs and then call into that boxed function.
140
+ *
141
+ * Note that this doesn't work for all types yet.
142
+ *
143
+ * Example:
144
+ *
145
+ * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
146
+ * > [] (Tensor a, bool b) -> Tensor {...});
147
+ * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
148
+ *
149
+ * Or, with a boxed implementation:
150
+ *
151
+ * > void boxed_func(OperatorKernel*, Stack* stack) {...}
152
+ * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
153
+ * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
154
+ */
155
+ template <class Return, class... Args>
156
+ Return call(
157
+ const OperatorHandle& opHandle,
158
+ DispatchKeySet dispatchKeySet,
159
+ Args... args) const;
160
+
161
+ /**
162
+ * Create a KernelFunction from a BoxedKernel.
163
+ */
164
+ static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);
165
+
166
+ /**
167
+ * Create a KernelFunction from a boxed function.
168
+ *
169
+ * Example:
170
+ *
171
+ * > void boxed_func(OperatorKernel*, Stack* stack) {...}
172
+ * > KernelFunction func =
173
+ * KernelFunction::makeFromBoxedFunction<&boxed_func>();
174
+ */
175
+ template <BoxedKernelFunction* func>
176
+ static KernelFunction makeFromBoxedFunction();
177
+
178
+ /**
179
+ * TODO: This will only be useful if we write a backend fallback that plumbs
180
+ * dispatch keys (currently there are none) See Note [Plumbing Keys Through
181
+ * The Dispatcher] for details.
182
+ */
183
+ template <BoxedKernelFunction_withDispatchKeys* func>
184
+ static KernelFunction makeFromBoxedFunction();
185
+
186
+ /**
187
+ * Create a KernelFunction from an unboxed functor.
188
+ *
189
+ * Example:
190
+ *
191
+ * > class MyFunctor final : public c10::OperatorKernel {
192
+ * > public:
193
+ * > Tensor operator()(Tensor a, Tensor b) {...}
194
+ * > };
195
+ * > KernelFunction func =
196
+ * KernelFunction::makeFromUnboxedFunctor<MyFunctor>(std::make_unique<MyFunctor>());
197
+ */
198
+ template <bool AllowLegacyTypes = false, class KernelFunctor>
199
+ static KernelFunction makeFromUnboxedFunctor(
200
+ std::unique_ptr<OperatorKernel> kernelFunctor);
201
+
202
+ /**
203
+ * Create a KernelFunction from a boxed functor.
204
+ *
205
+ * Example:
206
+ *
207
+ * > class MyFunctor final : public c10::OperatorKernel {
208
+ * > public:
209
+ * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
210
+ * > };
211
+ * > KernelFunction func =
212
+ * KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
213
+ */
214
+ template <class KernelFunctor>
215
+ static KernelFunction makeFromBoxedFunctor(
216
+ std::unique_ptr<KernelFunctor> kernelFunctor);
217
+
218
+ /**
219
+ * Create a KernelFunction from an unboxed function.
220
+ * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
221
+ * because knowing the function pointer as a template argument (i.e. at
222
+ * compile time) allows the compiler to inline the function into its
223
+ * unboxing wrapper and yields better performance when calling the function.
224
+ *
225
+ * Example:
226
+ *
227
+ * > Tensor unboxed_func(Tensor a, Tensor b) {...}
228
+ * > KernelFunction func =
229
+ * KernelFunction::makeFromUnboxedFunction<decltype(unboxed_func),
230
+ * &unboxed_func>();
231
+ */
232
+ template <class FuncPtr, bool AllowLegacyTypes = false>
233
+ static KernelFunction makeFromUnboxedFunction(FuncPtr /*func_ptr*/);
234
+
235
+ /**
236
+ * Create a KernelFunction from an unboxed function.
237
+ * KernelFunction::makeFromUnboxedFunction is usually a better choice than
238
+ * this if you know the function pointer at compile time, see doc comment
239
+ * there for an explanation.
240
+ *
241
+ * Example:
242
+ *
243
+ * > Tensor unboxed_func(Tensor a, Tensor b) {...}
244
+ * > KernelFunction func =
245
+ * KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
246
+ */
247
+ template <bool AllowLegacyTypes = false, class FuncType>
248
+ static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
249
+
250
+ static KernelFunction makeFallthrough();
251
+ static KernelFunction makeAmbiguousAutogradOther();
252
+ static KernelFunction makeNamedNotSupported();
253
+
254
+ /**
255
+ * Create a KernelFunction from an unboxed lambda.
256
+ *
257
+ * Example:
258
+ *
259
+ * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
260
+ * > [] (Tensor a, bool b) -> Tensor {...});
261
+ */
262
+ template <bool AllowLegacyTypes = false, class Lambda>
263
+ static std::enable_if_t<
264
+ guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
265
+ KernelFunction>
266
+ makeFromUnboxedLambda(Lambda&& lambda);
267
+ template <bool AllowLegacyTypes = false, class Lambda>
268
+ static std::enable_if_t<
269
+ !guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
270
+ KernelFunction>
271
+ makeFromUnboxedLambda(Lambda&& lambda);
272
+
273
+ std::string dumpState() const;
274
+ // For testing internal invariants only
275
+ bool _equalsBoxedAndUnboxed(const KernelFunction& /*other*/) const;
276
+
277
+ // Register a token to be invalidated when this KernelFunction is destroyed
278
+ void registerToken(std::weak_ptr<KernelToken> token) const;
279
+
280
+ private:
281
+ explicit KernelFunction(
282
+ std::unique_ptr<OperatorKernel> functor,
283
+ InternalBoxedKernelFunction* boxed_kernel_func,
284
+ void* unboxed_kernel_func,
285
+ void* sym_unboxed_kernel_func);
286
+ explicit KernelFunction(
287
+ BoxedKernel boxed_fn,
288
+ void* unboxed_kernel_func,
289
+ void* sym_unboxed_kernel_func);
290
+
291
+ BoxedKernel boxed_kernel_func_;
292
+ void* unboxed_kernel_func_;
293
+ void* sym_unboxed_kernel_func_;
294
+ // List of tokens that need to be invalidated when this KernelFunction is
295
+ // destroyed (lazy allocation to save memory when empty)
296
+ mutable std::unique_ptr<std::vector<std::weak_ptr<KernelToken>>> tokens_;
297
+ };
298
+
299
+ // Token held by SafeKernelFunction that gets invalidated when KernelFunction is
300
+ // destroyed
301
+ class KernelToken {
302
+ public:
303
+ bool isValid() const;
304
+ void invalidate();
305
+
306
+ private:
307
+ std::atomic<bool> invalid_{false};
308
+ };
309
+
310
+ class SafeKernelFunction {
311
+ public:
312
+ SafeKernelFunction(
313
+ const KernelFunction* kernel,
314
+ std::string debug,
315
+ std::shared_ptr<OperatorHandle> opHandle);
316
+
317
+ // Safe callBoxed - checks token validity first
318
+ void callBoxed(
319
+ const OperatorHandle& opHandle,
320
+ DispatchKeySet dispatchKeySet,
321
+ Stack* stack) const;
322
+
323
+ // Get debug information
324
+ const std::string& debug() const {
325
+ return debug_;
326
+ }
327
+
328
+ // Get the OpHandle that lives on this SafeKernelFunction
329
+ const OperatorHandle& opHandle() const {
330
+ return *opHandle_;
331
+ }
332
+
333
+ private:
334
+ KernelFunction kernel_;
335
+ std::shared_ptr<KernelToken> token_;
336
+ std::string debug_;
337
+ std::shared_ptr<OperatorHandle> opHandle_;
338
+ };
339
+
340
+ } // namespace c10
341
+
342
+ #include <ATen/core/boxing/KernelFunction_impl.h>
343
+
344
+ #else
345
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
346
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
3
+ #include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h>
4
+ #include <ATen/core/boxing/impl/boxing.h>
5
+ #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
6
+
7
+ #include <c10/util/C++17.h>
8
+ #include <type_traits>
9
+
10
+ namespace c10 {
11
+
12
+ namespace detail {
13
+ template <typename Base, typename Child, typename... Args>
14
+ std::enable_if_t<
15
+ !std::is_array_v<Base> && !std::is_array_v<Child> &&
16
+ std::is_base_of_v<Base, Child>,
17
+ std::unique_ptr<Base>>
18
+ make_unique_base(Args&&... args) {
19
+ return std::make_unique<Child>(std::forward<Args>(args)...);
20
+ }
21
+ } // namespace detail
22
+
23
+ inline KernelFunction::KernelFunction()
24
+ : unboxed_kernel_func_(nullptr), sym_unboxed_kernel_func_(nullptr) {}
25
+
26
+ inline KernelFunction::~KernelFunction() {
27
+ if (tokens_) {
28
+ for (auto& weak_token : *tokens_) {
29
+ if (auto token = weak_token.lock()) {
30
+ token->invalidate();
31
+ }
32
+ }
33
+ }
34
+ }
35
+
36
+ inline KernelFunction::KernelFunction(const KernelFunction& other)
37
+ : boxed_kernel_func_(other.boxed_kernel_func_),
38
+ unboxed_kernel_func_(other.unboxed_kernel_func_),
39
+ sym_unboxed_kernel_func_(other.sym_unboxed_kernel_func_) {
40
+ // tokens_ is intentionally not copied as we only care about invalidating
41
+ // tokens if the original KernelFunction is destroyed
42
+ }
43
+
44
+ inline KernelFunction& KernelFunction::operator=(const KernelFunction& other) {
45
+ if (this != &other) {
46
+ boxed_kernel_func_ = other.boxed_kernel_func_;
47
+ unboxed_kernel_func_ = other.unboxed_kernel_func_;
48
+ sym_unboxed_kernel_func_ = other.sym_unboxed_kernel_func_;
49
+
50
+ // tokens_ is intentionally not copied as we only care about invalidating
51
+ // tokens if the original KernelFunction is destroyed
52
+ }
53
+ return *this;
54
+ }
55
+
56
+ inline KernelFunction::KernelFunction(
57
+ std::unique_ptr<OperatorKernel> functor,
58
+ InternalBoxedKernelFunction* boxed_kernel_func,
59
+ void* unboxed_kernel_func,
60
+ void* sym_unboxed_kernel_func = nullptr)
61
+ : boxed_kernel_func_(std::move(functor), boxed_kernel_func),
62
+ unboxed_kernel_func_(unboxed_kernel_func),
63
+ sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}
64
+
65
+ inline KernelFunction::KernelFunction(
66
+ BoxedKernel boxed_fn,
67
+ void* unboxed_kernel_func,
68
+ void* sym_unboxed_kernel_func = nullptr)
69
+ : boxed_kernel_func_(std::move(boxed_fn)),
70
+ unboxed_kernel_func_(unboxed_kernel_func),
71
+ sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {}
72
+
73
+ inline bool KernelFunction::isValidUnboxed() const {
74
+ return unboxed_kernel_func_ != nullptr;
75
+ }
76
+
77
+ inline bool KernelFunction::isValidSymUnboxed() const {
78
+ return sym_unboxed_kernel_func_ != nullptr;
79
+ }
80
+
81
+ inline bool KernelFunction::isValid() const {
82
+ return boxed_kernel_func_.isValid();
83
+ }
84
+
85
+ inline bool KernelFunction::isFallthrough() const {
86
+ return boxed_kernel_func_.isFallthrough();
87
+ }
88
+
89
+ inline void KernelFunction::callBoxed(
90
+ const OperatorHandle& opHandle,
91
+ DispatchKeySet dispatchKeySet,
92
+ Stack* stack) const {
93
+ boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack);
94
+ }
95
+
96
+ template <class Return, class... Args>
97
+ inline Return callUnboxedKernelFunction(
98
+ void* unboxed_kernel_func,
99
+ OperatorKernel* functor,
100
+ DispatchKeySet dispatchKeySet,
101
+ Args&&... args) {
102
+ using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...);
103
+ ActualSignature* func =
104
+ reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
105
+ return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
106
+ }
107
+
108
+ // This template requires you to explicitly specify the argument you want to
109
+ // forward; it doesn't work if you try to deduce it
110
+ // NB: keep this in sync with cloneWithRealTypes in function_schema.cpp
111
+
112
+ template <typename T>
113
+ inline typename remove_symint<T>::type unpackSymInt(T x) {
114
+ return x;
115
+ }
116
+
117
+ template <>
118
+ inline remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
119
+ return x.guard_int(__FILE__, __LINE__);
120
+ }
121
+
122
+ template <>
123
+ inline remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
124
+ c10::SymIntArrayRef x) {
125
+ return C10_AS_INTARRAYREF_SLOW(x);
126
+ }
127
+
128
+ template <>
129
+ inline remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
130
+ std::optional<c10::SymInt> x) {
131
+ return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
132
+ : std::nullopt;
133
+ }
134
+
135
+ template <>
136
+ inline remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
137
+ at::OptionalSymIntArrayRef x) {
138
+ return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
139
+ : std::nullopt;
140
+ }
141
+
142
+ template <class Return, class... Args>
143
+ C10_ALWAYS_INLINE Return KernelFunction::call(
144
+ const OperatorHandle& opHandle,
145
+ DispatchKeySet dispatchKeySet,
146
+ Args... args) const {
147
+ // note: Args above is intentionally not Args&&. We don't want perfect
148
+ // forwarding, which would require Args to be deduced, but instead we
149
+ // want callers to explicitly specify the Args.
150
+
151
+ if constexpr (std::disjunction_v<has_symint<Args>...>) {
152
+ if (sym_unboxed_kernel_func_ != nullptr) {
153
+ auto* functor = boxed_kernel_func_.getFunctor();
154
+ return callUnboxedKernelFunction<Return, Args...>(
155
+ sym_unboxed_kernel_func_,
156
+ functor,
157
+ dispatchKeySet,
158
+ std::forward<Args>(args)...);
159
+ }
160
+
161
+ if (unboxed_kernel_func_ != nullptr) {
162
+ auto* functor = boxed_kernel_func_.getFunctor();
163
+ return callUnboxedKernelFunction<
164
+ Return,
165
+ typename remove_symint<Args>::type...>(
166
+ unboxed_kernel_func_,
167
+ functor,
168
+ dispatchKeySet,
169
+ unpackSymInt<Args>(args)...);
170
+ }
171
+ } else {
172
+ if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
173
+ auto* functor = boxed_kernel_func_.getFunctor();
174
+ return callUnboxedKernelFunction<Return, Args...>(
175
+ unboxed_kernel_func_,
176
+ functor,
177
+ dispatchKeySet,
178
+ std::forward<Args>(args)...);
179
+ }
180
+ }
181
+
182
+ return impl::BoxedKernelWrapper<Return(Args...)>::call(
183
+ boxed_kernel_func_,
184
+ opHandle,
185
+ dispatchKeySet,
186
+ std::forward<Args>(args)...);
187
+ }
188
+
189
+ inline void KernelFunction::registerToken(
190
+ std::weak_ptr<KernelToken> token) const {
191
+ if (!tokens_) {
192
+ tokens_ = std::make_unique<std::vector<std::weak_ptr<KernelToken>>>();
193
+ }
194
+ tokens_->push_back(std::move(token));
195
+ }
196
+
197
+ inline KernelFunction KernelFunction::makeFromBoxedKernel(
198
+ BoxedKernel boxed_fn) {
199
+ return KernelFunction(
200
+ std::move(boxed_fn), nullptr); // no unboxed function pointer
201
+ }
202
+
203
+ template <KernelFunction::BoxedKernelFunction* func>
204
+ inline KernelFunction KernelFunction::makeFromBoxedFunction() {
205
+ return KernelFunction::makeFromBoxedKernel(
206
+ BoxedKernel::makeFromFunction<func>());
207
+ }
208
+
209
+ template <KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
210
+ inline KernelFunction KernelFunction::makeFromBoxedFunction() {
211
+ return KernelFunction::makeFromBoxedKernel(
212
+ BoxedKernel::makeFromFunction<func>());
213
+ }
214
+
215
+ inline KernelFunction KernelFunction::makeFallthrough() {
216
+ return KernelFunction::makeFromBoxedKernel(BoxedKernel::makeFallthrough());
217
+ }
218
+
219
+ inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() {
220
+ return KernelFunction::makeFromBoxedKernel(
221
+ BoxedKernel::makeAmbiguousAutogradOther());
222
+ }
223
+
224
+ inline KernelFunction KernelFunction::makeNamedNotSupported() {
225
+ return KernelFunction::makeFromBoxedKernel(
226
+ BoxedKernel::makeNamedNotSupported());
227
+ }
228
+
229
+ template <bool AllowLegacyTypes, class KernelFunctor>
230
+ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(
231
+ std::unique_ptr<OperatorKernel> kernelFunctor) {
232
+ #ifndef NDEBUG
233
+ // This assertion is costly for build time so it's debug-gated.
234
+ static_assert(
235
+ guts::is_functor<KernelFunctor>::value,
236
+ "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
237
+ #endif
238
+ static_assert(
239
+ std::is_base_of_v<OperatorKernel, KernelFunctor>,
240
+ "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
241
+
242
+ auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
243
+ void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
244
+ bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
245
+ return KernelFunction(
246
+ std::move(kernelFunctor),
247
+ &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::
248
+ call,
249
+ is_symint ? nullptr : void_unboxed_fn,
250
+ is_symint ? void_unboxed_fn : nullptr);
251
+ }
252
+
253
+ template <class KernelFunctor>
254
+ inline KernelFunction KernelFunction::makeFromBoxedFunctor(
255
+ std::unique_ptr<KernelFunctor> kernelFunctor) {
256
+ return KernelFunction::makeFromBoxedKernel(
257
+ BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
258
+ }
259
+
260
+ template <class FuncPtr, bool AllowLegacyTypes>
261
+ inline KernelFunction KernelFunction::makeFromUnboxedFunction(
262
+ FuncPtr func_ptr) {
263
+ static_assert(
264
+ is_compile_time_function_pointer<FuncPtr>::value,
265
+ "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
266
+ static_assert(
267
+ !std::is_same_v<typename FuncPtr::FuncType, BoxedKernelFunction>,
268
+ "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
269
+ #if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__)
270
+ TORCH_INTERNAL_ASSERT(
271
+ FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
272
+ #else
273
+ static_assert(
274
+ FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
275
+ #endif
276
+
277
+ #if !defined(C10_MOBILE)
278
+ (void)func_ptr; // Suppress unused variable warning
279
+ return makeFromUnboxedFunctor<
280
+ AllowLegacyTypes,
281
+ typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>(
282
+ detail::make_unique_base<
283
+ OperatorKernel,
284
+ typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>());
285
+ #else
286
+ // On mobile, we rather want to optimize for binary size than for performance,
287
+ // so let's not inline the kernel into the wrapper but use
288
+ // makeFromUnboxedRuntimeFunction instead.
289
+ return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr());
290
+ #endif
291
+ }
292
+
293
+ template <bool AllowLegacyTypes, class FuncType>
294
+ inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(
295
+ FuncType* func) {
296
+ static_assert(
297
+ guts::is_function_type<FuncType>::value,
298
+ "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
299
+ static_assert(
300
+ !std::is_same_v<FuncType, BoxedKernelFunction>,
301
+ "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
302
+ TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
303
+
304
+ return makeFromUnboxedFunctor<
305
+ AllowLegacyTypes,
306
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
307
+ detail::make_unique_base<
308
+ OperatorKernel,
309
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func));
310
+ }
311
+
312
+ template <bool AllowLegacyTypes, class Lambda>
313
+ inline std::enable_if_t<
314
+ guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
315
+ KernelFunction>
316
+ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
317
+ static_assert(
318
+ guts::is_functor<std::decay_t<Lambda>>::value,
319
+ "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
320
+
321
+ #if !defined(C10_MOBILE)
322
+ return makeFromUnboxedFunctor<
323
+ AllowLegacyTypes,
324
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
325
+ detail::make_unique_base<
326
+ OperatorKernel,
327
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
328
+ std::forward<Lambda>(lambda)));
329
+ #else
330
+ // On mobile, we rather want to optimize for binary size than for performance,
331
+ // so let's not inline the kernel into the wrapper but use
332
+ // makeFromUnboxedRuntimeFunction instead.
333
+ using FuncType =
334
+ typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type;
335
+ return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda);
336
+ #endif
337
+ }
338
+
339
+ template <bool AllowLegacyTypes, class Lambda>
340
+ inline std::enable_if_t<
341
+ !guts::is_stateless_lambda<std::decay_t<Lambda>>::value,
342
+ KernelFunction>
343
+ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
344
+ static_assert(
345
+ guts::is_functor<std::decay_t<Lambda>>::value,
346
+ "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
347
+
348
+ return makeFromUnboxedFunctor<
349
+ AllowLegacyTypes,
350
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
351
+ detail::make_unique_base<
352
+ OperatorKernel,
353
+ impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
354
+ std::forward<Lambda>(lambda)));
355
+ }
356
+
357
+ inline bool KernelToken::isValid() const {
358
+ return !invalid_.load(std::memory_order_acquire);
359
+ }
360
+
361
+ inline void KernelToken::invalidate() {
362
+ invalid_.store(true, std::memory_order_release);
363
+ }
364
+
365
+ inline SafeKernelFunction::SafeKernelFunction(
366
+ const KernelFunction* kernel,
367
+ std::string debug,
368
+ std::shared_ptr<OperatorHandle> opHandle)
369
+ : kernel_(kernel ? *kernel : KernelFunction()),
370
+ token_(std::make_shared<KernelToken>()),
371
+ debug_(std::move(debug)),
372
+ opHandle_(std::move(opHandle)) {
373
+ // Register the token with the original kernel so it gets invalidated when the
374
+ // kernel is destroyed
375
+ if (kernel) {
376
+ kernel->registerToken(token_);
377
+ }
378
+ }
379
+
380
+ inline void SafeKernelFunction::callBoxed(
381
+ const OperatorHandle& opHandle,
382
+ DispatchKeySet dispatchKeySet,
383
+ Stack* stack) const {
384
+ TORCH_CHECK(
385
+ token_ && token_->isValid(),
386
+ "SafeKernelFunction has been invalidated ",
387
+ debug_);
388
+ kernel_.callBoxed(opHandle, dispatchKeySet, stack);
389
+ }
390
+
391
+ } // namespace c10
392
+
393
+ #else
394
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
395
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <c10/util/intrusive_ptr.h>
4
+
5
+ namespace c10 {
6
+
7
+ /**
8
+ * Inherit from OperatorKernel to implement a c10 kernel.
9
+ *
10
+ * Example:
11
+ * > namespace {
12
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
13
+ * > public:
14
+ * > Tensor operator()(Tensor a, Tensor b) {...}
15
+ * > };
16
+ * > }
17
+ *
18
+ * The kernel class is allowed to have members but these are equivalent
19
+ * to global variables. The kernel implementation is responsible for
20
+ * preventing race conditions on them.
21
+ *
22
+ * See below for how to register this kernel with PyTorch.
23
+ */
24
+ struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target {
25
+ ~OperatorKernel() override = default;
26
+ };
27
+
28
+ } // namespace c10
29
+
30
+ #else
31
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
32
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/CompileTimeFunctionPointer.h>
5
+
6
+ namespace c10::impl {
7
+ namespace detail {
8
+ template <class FuncPtr, class ReturnType, class ParameterList>
9
+ class WrapFunctionIntoFunctor_ {};
10
+ template <class FuncPtr, class ReturnType, class... Parameters>
11
+ class WrapFunctionIntoFunctor_<
12
+ FuncPtr,
13
+ ReturnType,
14
+ guts::typelist::typelist<Parameters...>>
15
+ final : public c10::OperatorKernel {
16
+ public:
17
+ C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) {
18
+ return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...);
19
+ }
20
+ };
21
+ } // namespace detail
22
+
23
+ // WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel
24
+ // functor. Since it is a compile time function pointer, many compilers can
25
+ // inline it into the wrapper and you don't get any performance overhead for
26
+ // wrapping.
27
+ template <class FuncPtr>
28
+ struct WrapFunctionIntoFunctor final {
29
+ static_assert(
30
+ c10::is_compile_time_function_pointer<FuncPtr>::value,
31
+ "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN.");
32
+ using type = detail::WrapFunctionIntoFunctor_<
33
+ FuncPtr,
34
+ typename guts::function_traits<typename FuncPtr::FuncType>::return_type,
35
+ typename guts::function_traits<
36
+ typename FuncPtr::FuncType>::parameter_types>;
37
+ };
38
+
39
+ } // namespace c10::impl
40
+
41
+ #else
42
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
43
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/util/TypeTraits.h>
5
+
6
+ namespace c10::impl {
7
+
8
+ namespace detail {
9
+ template <class FuncType, class ReturnType, class ParameterList>
10
+ class WrapFunctionIntoRuntimeFunctor_ {};
11
+ template <class FuncType, class ReturnType, class... Parameters>
12
+ class WrapFunctionIntoRuntimeFunctor_<
13
+ FuncType,
14
+ ReturnType,
15
+ guts::typelist::typelist<Parameters...>>
16
+ final : public c10::OperatorKernel {
17
+ public:
18
+ template <class FuncType_>
19
+ explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func)
20
+ : kernel_func_(std::forward<FuncType_>(kernel_func)) {}
21
+
22
+ decltype(auto) operator()(Parameters... args) {
23
+ return kernel_func_(std::forward<Parameters>(args)...);
24
+ }
25
+
26
+ private:
27
+ FuncType kernel_func_;
28
+ };
29
+ } // namespace detail
30
+
31
+ // WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that
32
+ // inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
33
+ // This can, for example, be used for lambdas, functors or even function
34
+ // pointers. In the case of function pointers, since it is a runtime function
35
+ // pointer, there is an overhead for calling it whenever the kernel is invoked.
36
+ template <class FuncType>
37
+ using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_<
38
+ FuncType,
39
+ typename guts::infer_function_traits_t<FuncType>::return_type,
40
+ typename guts::infer_function_traits_t<FuncType>::parameter_types>;
41
+
42
+ } // namespace c10::impl
43
+
44
+ #else
45
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
46
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/boxing.h ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // This file contains boxing (not unboxing) logic,
5
+ // i.e. how to make a vector<IValue> from a set of concrete arguments.
6
+
7
+ #include <ATen/core/ivalue.h>
8
+ #include <ATen/core/stack.h>
9
+ #include <c10/core/TensorOptions.h>
10
+
11
+ #include <ATen/core/boxing/BoxedKernel.h>
12
+
13
+ #include <c10/util/Metaprogramming.h>
14
+ #include <type_traits>
15
+
16
+ namespace c10::impl {
17
+
18
+ //
19
+ // utils
20
+ //
21
+
22
+ // is_mutable_tensor_ref
23
+ template <class T>
24
+ struct is_mutable_tensor_ref : std::false_type {};
25
+ template <>
26
+ struct is_mutable_tensor_ref<at::Tensor&> : std::true_type {};
27
+
28
+ // is_tuple_of_mutable_tensor_refs
29
+ //
30
+ template <class T, class Enable = void>
31
+ struct is_tuple_of_mutable_tensor_refs : std::false_type {};
32
+
33
+ template <class T>
34
+ struct is_tuple_of_mutable_tensor_refs<
35
+ T,
36
+ std::enable_if_t<guts::is_instantiation_of<std::tuple, T>::value, void>>
37
+ : guts::typelist::
38
+ all<is_mutable_tensor_ref, guts::typelist::from_tuple_t<T>> {};
39
+
40
+ // has_ivalue_to<T> tests the presence/absence of instance method
41
+ // IValue::to<T>()
42
+ //
43
+ template <class T, class Enable = void>
44
+ struct has_ivalue_to : std::false_type {};
45
+
46
+ template <class T>
47
+ struct ivalue_to_helper {
48
+ using type = decltype(std::declval<IValue>().template to<T>());
49
+ };
50
+ template <class T>
51
+ using ivalue_to_helper_t = typename ivalue_to_helper<T>::type;
52
+
53
+ template <class T>
54
+ struct has_ivalue_to<T, std::void_t<ivalue_to_helper_t<T>>> : std::true_type {};
55
+
56
+ //
57
+ // boxing predicates
58
+ //
59
+
60
+ // A boxable arg type is one that IValue has a constructor for.
61
+ template <typename T>
62
+ using can_box = std::disjunction<
63
+ std::is_constructible<IValue, std::decay_t<T>>,
64
+ // TensorOptions are not directly constructible into IValue,
65
+ // but torch::jit::push knows how to handle them
66
+ std::is_same<TensorOptions, std::decay_t<T>>>;
67
+
68
+ template <typename... Ts>
69
+ using can_box_all = std::conjunction<can_box<Ts>...>;
70
+
71
+ // an unboxable result is one that can be extracted from an IValue
72
+ template <typename T>
73
+ using can_unbox = std::conjunction<
74
+ std::disjunction<
75
+ has_ivalue_to<T>,
76
+ // void returns are ok
77
+ std::is_same<void, T>>,
78
+ std::negation<std::is_lvalue_reference<T>>>;
79
+
80
+ //
81
+ // boxArgs - utility for pushing unboxed args onto IValue stack
82
+ //
83
+ template <class... Args>
84
+ torch::jit::Stack boxArgs(Args... args) {
85
+ // TODO Reuse stack vector instead of allocating?
86
+ torch::jit::Stack stack;
87
+ stack.reserve(sizeof...(Args));
88
+ torch::jit::push(stack, std::forward<Args>(args)...);
89
+ return stack;
90
+ }
91
+
92
+ template <class T>
93
+ inline constexpr size_t boxed_size_one() {
94
+ static_assert(
95
+ !std::is_same_v<std::decay_t<T>, c10::TensorOptions>,
96
+ "need to patch this path to support TensorOptions passed by reference");
97
+ return 1;
98
+ }
99
+
100
+ // torch::jit::push pushes 4 values for a TensorOptions; this needs to
101
+ // be kept in sync.
102
+ template <>
103
+ inline constexpr size_t boxed_size_one<c10::TensorOptions>() {
104
+ return 4;
105
+ }
106
+
107
+ // NOTE: this could probably be simplified with C++17 fold expressions.
108
+ template <typename...>
109
+ struct BoxedSize : std::integral_constant<size_t, 0> {};
110
+ template <class T, class... Args>
111
+ struct BoxedSize<T, Args...>
112
+ : std::integral_constant<
113
+ size_t,
114
+ boxed_size_one<T>() + BoxedSize<Args...>::value> {};
115
+
116
+ template <class... Args>
117
+ static inline constexpr size_t boxed_size() {
118
+ return BoxedSize<Args...>::value;
119
+ }
120
+
121
+ template <typename T>
122
+ C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValue*& dest, T& arg) {
123
+ new (dest++) IValue(arg);
124
+ }
125
+
126
+ C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(
127
+ IValue*& dest,
128
+ c10::TensorOptions options) {
129
+ new (dest++) IValue(c10::typeMetaToScalarType(options.dtype()));
130
+ new (dest++) IValue(options.layout());
131
+ new (dest++) IValue(options.device());
132
+ new (dest++) IValue(options.pinned_memory());
133
+ }
134
+
135
+ inline void boxArgsToStack(IValue*& /*unused*/) {}
136
+
137
+ template <typename T, typename... Args>
138
+ C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(
139
+ IValue*& dest,
140
+ T& arg,
141
+ Args&... args) {
142
+ boxToStack(dest, arg);
143
+ boxArgsToStack(dest, args...);
144
+ }
145
+
146
+ //
147
+ // PopResult is a helper class whose specializations handle popping single and
148
+ // multiple return values, respectively.
149
+ //
150
+ template <class Result>
151
+ struct PopResult final {
152
+ static Result call(Stack& stack) {
153
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
154
+ stack.size() == 1,
155
+ "Boxed kernel was expected to return one value on the stack, ",
156
+ "but instead pushed ",
157
+ stack.size(),
158
+ " values.");
159
+ return std::move(stack[0]).to<Result>();
160
+ }
161
+ };
162
+
163
+ template <class... Types>
164
+ struct PopResult<std::tuple<Types...>> final {
165
+ using Result = std::tuple<Types...>;
166
+
167
+ static Result call(Stack& stack) {
168
+ // for tuple return types, boxed kernel has pushed multiple values onto the
169
+ // stack
170
+ constexpr int RetCount = sizeof...(Types);
171
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
172
+ stack.size() == RetCount,
173
+ "Boxed kernel was expected to return ",
174
+ RetCount,
175
+ " values on the stack, ",
176
+ "but instead pushed ",
177
+ stack.size(),
178
+ " values.");
179
+ return pop_to_tuple_impl(stack, std::make_index_sequence<RetCount>());
180
+ }
181
+
182
+ private:
183
+ // note: this has been moved into its own helper only to avoid a parse error
184
+ // on `indices` otherwise. I'm sure there's an incantation that slips it past
185
+ // the parser but eh
186
+ template <size_t... indices>
187
+ static Result pop_to_tuple_impl(
188
+ Stack& stack,
189
+ std::index_sequence<indices...> /*unused*/) {
190
+ return std::make_tuple((std::move(stack[indices]).template to<Types>())...);
191
+ }
192
+ };
193
+
194
+ //
195
+ // BoxedKernelWrapper
196
+ //
197
+ // For a given function type FT, BoxedKernelWrapper<FT> implements
198
+ // a `call` method that
199
+ // - takes a boxed kernel and unboxed arguments as specified by FT,
200
+ // - calls `boxArgs` to box the arguments
201
+ // - calls the boxed kernel
202
+ // - unboxes and returns the result
203
+ //
204
+ // The partial specializations below handle various cases: in
205
+ // particular, not all types appearing in op signatures are supported,
206
+ // and ops returning references have nonstandard wrapper implementations.
207
+ //
208
+
209
+ // 1. The base specialization of BoxedKernelWrapper should never be
210
+ // instantiated. A "no call method defined on BoxedKernelWrapper" compile error
211
+ // means that an op signature has failed to trigger any of the partial
212
+ // specializations that follow this one.
213
+ //
214
+ template <class FuncType, class Enable = void>
215
+ struct BoxedKernelWrapper {
216
+ // The reason we're not just doing straight up static_assert(false, ...) here:
217
+ // Basically, the way to make sure a static_assert only fires if a template
218
+ // is actually instantiated (rather than every time the file is parsed) is to
219
+ // use template parameters in the expression, e.g. FuncType here. However,
220
+ // since `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the
221
+ // same effect.
222
+ static_assert(
223
+ sizeof(FuncType) != sizeof(FuncType),
224
+ "Function signature contains one or more unsupported parameter and/or return types. "
225
+ "Look for a nearby error like "
226
+ "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "
227
+ "- (your function type) is the unsupported signature.");
228
+ };
229
+
230
+ //
231
+ // 2. Supported signatures, other than those involving non-const Tensor refs -
232
+ // i.e., "functional" ops.
233
+ //
234
+
235
+ template <class Result, class... Args>
236
+ struct BoxedKernelWrapper<
237
+ Result(Args...),
238
+ std::enable_if_t<
239
+ can_box_all<Args...>::value && can_unbox<Result>::value &&
240
+ !is_tuple_of_mutable_tensor_refs<Result>::value,
241
+ void>> {
242
+ static Result call(
243
+ const BoxedKernel& boxed_kernel_func,
244
+ const OperatorHandle& opHandle,
245
+ DispatchKeySet dispatchKeySet,
246
+ Args... args) {
247
+ torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
248
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
249
+
250
+ if constexpr (!std::is_same_v<void, Result>) {
251
+ // op has pushed one or more values onto the stack.
252
+ return PopResult<Result>::call(stack);
253
+ } else {
254
+ // op returns void, boxed kernel has pushed nothing onto stack.
255
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
256
+ stack.empty(),
257
+ "Boxed kernel was expected to return no values on the stack, ",
258
+ "but instead returned ",
259
+ stack.size(),
260
+ " values.");
261
+ }
262
+ }
263
+ };
264
+
265
+ //
266
+ // 3. in-place ops take a single non-const Tensor reference
267
+ // as their first argument, and return it.
268
+ //
269
+ // Note: all signatures matching this pattern are assumed to be for such ops.
270
+ // Because of this, the generated BoxedKernelWrapper specializations simply
271
+ // return the in-place argument.
272
+ //
273
+
274
+ template <class... OtherArgs>
275
+ struct BoxedKernelWrapper<
276
+ at::Tensor&(at::Tensor&, OtherArgs...),
277
+ std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
278
+ static at::Tensor& call(
279
+ const BoxedKernel& boxed_kernel_func,
280
+ const OperatorHandle& opHandle,
281
+ DispatchKeySet dispatchKeySet,
282
+ at::Tensor& outArg,
283
+ OtherArgs... otherArgs) {
284
+ torch::jit::Stack stack = boxArgs<at::Tensor&, OtherArgs...>(
285
+ outArg, std::forward<OtherArgs>(otherArgs)...);
286
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
287
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
288
+ stack.size() == 1,
289
+ "Boxed kernel was expected to return a single value on the stack, ",
290
+ "but instead returned ",
291
+ stack.size(),
292
+ " values.");
293
+
294
+ return outArg;
295
+ }
296
+ };
297
+
298
+ //
299
+ // 3.5. In-process migration to make in-place ops take and return
300
+ // const references instead.
301
+ template <class... OtherArgs>
302
+ struct BoxedKernelWrapper<
303
+ const at::Tensor&(const at::Tensor&, OtherArgs...),
304
+ std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
305
+ static const at::Tensor& call(
306
+ const BoxedKernel& boxed_kernel_func,
307
+ const OperatorHandle& opHandle,
308
+ DispatchKeySet dispatchKeySet,
309
+ const at::Tensor& outArg,
310
+ OtherArgs... otherArgs) {
311
+ torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
312
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
313
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
314
+ stack.size() == 1,
315
+ "Boxed kernel was expected to return a single value on the stack, ",
316
+ "but instead returned ",
317
+ stack.size(),
318
+ " values.");
319
+
320
+ return outArg;
321
+ }
322
+ };
323
+
324
+ //
325
+ // 4. out of place ops that take a single non-const Tensor reference as their
326
+ // final argument, and also return it.
327
+ //
328
+ // Note: all signatures matching this pattern are assumed to be for such ops.
329
+ // This assumption permits the generated BoxedKernelWrapper specializations to
330
+ // simply return out arguments.
331
+ //
332
+ template <class FirstArg, class... RestArgs>
333
+ struct BoxedKernelWrapper<
334
+ at::Tensor&(FirstArg, RestArgs...),
335
+ std::enable_if_t<
336
+ can_box_all<FirstArg, RestArgs...>::value
337
+ // this skips over in-place kernels with a non-const Tensor
338
+ // arg at the front, so those can unambiguously trigger the
339
+ // preceding specialization.
340
+ && !is_mutable_tensor_ref<FirstArg>::value,
341
+ void>> {
342
+ static at::Tensor& call(
343
+ const BoxedKernel& boxed_kernel_func,
344
+ const OperatorHandle& opHandle,
345
+ DispatchKeySet dispatchKeySet,
346
+ FirstArg firstArg,
347
+ RestArgs... restArgs) {
348
+ torch::jit::Stack stack = boxArgs<FirstArg, RestArgs...>(
349
+ std::forward<FirstArg>(firstArg), std::forward<RestArgs>(restArgs)...);
350
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
351
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
352
+ stack.size() == 1,
353
+ "Boxed kernel was expected to return a single value on the stack, ",
354
+ "but instead returned ",
355
+ stack.size(),
356
+ " values.");
357
+
358
+ // reusing restArgs after it has been forwarded here is ok because we know
359
+ // that the last element is of type `Tensor&`.
360
+ return std::get<sizeof...(RestArgs) - 1>(
361
+ std::tuple<RestArgs...>{restArgs...});
362
+ }
363
+ };
364
+
365
+ //
366
+ // 5. out of place ops that take multiple non-const Tensor references as their
367
+ // final arguments, and return them in a std::tuple.
368
+ //
369
+ // Note: all signatures matching this pattern are assumed to be for such ops.
370
+ // This assumption permits the generated BoxedKernelWrapper specializations to
371
+ // simply return the out arguments.
372
+ //
373
+ template <class Result, class... Args>
374
+ struct BoxedKernelWrapper<
375
+ Result(Args...),
376
+ std::enable_if_t<
377
+ can_box_all<Args...>::value &&
378
+ is_tuple_of_mutable_tensor_refs<Result>::value,
379
+ void>> {
380
+ static Result call(
381
+ const BoxedKernel& boxed_kernel_func,
382
+ const OperatorHandle& opHandle,
383
+ DispatchKeySet dispatchKeySet,
384
+ Args... args) {
385
+ using ArgTuple = std::tuple<Args...>;
386
+ constexpr int RetCount = std::tuple_size<Result>();
387
+
388
+ torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
389
+ boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
390
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
391
+ stack.size() == RetCount,
392
+ "Boxed kernel was expected to return ",
393
+ RetCount,
394
+ " values on the stack, ",
395
+ "but instead returned ",
396
+ stack.size(),
397
+ " values.");
398
+
399
+ // reusing args after it has been forwarded here is ok because we know
400
+ // that the last RetCount elements are of type `Tensor&`.
401
+ auto result = guts::tuple_take<ArgTuple, -RetCount>(
402
+ ArgTuple{std::forward<Args>(args)...});
403
+ static_assert(
404
+ std::is_same_v<Result, decltype(result)>,
405
+ "The parameter list of an op returning a tuple of Tensor references "
406
+ "must end with an equal number of Tensor reference parameters.");
407
+ return result;
408
+ }
409
+ };
410
+
411
+ } // namespace c10::impl
412
+
413
+ #else
414
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
415
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h ADDED
@@ -0,0 +1,790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/IListRef.h>
5
+ #include <ATen/core/boxing/OperatorKernel.h>
6
+ #include <ATen/core/ivalue.h>
7
+ #include <ATen/core/stack.h>
8
+ #include <c10/util/Metaprogramming.h>
9
+ #include <c10/util/TypeList.h>
10
+ #include <c10/util/intrusive_ptr.h>
11
+
12
+ #include <utility>
13
+
14
+ namespace c10 {
15
+
16
+ using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack
17
+ // to the c10 namespace.
18
+ class OperatorHandle;
19
+
20
+ /*
21
+ * [Note: Argument forwarding in the dispatcher]
22
+ *
23
+ * The dispatcher uses a somewhat unusual way to forward arguments through
24
+ * several layers of wrapper functions. This can be confusing because an
25
+ * experienced C++ programmer would look at this and think "oh this is supposed
26
+ * to be forwarding a universal reference but the && is missing. This is a
27
+ * bug.". It is not a bug. The common way in C++ to forward arguments is to use
28
+ * universal references:
29
+ *
30
+ * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
31
+ *
32
+ * but that relies on inferring the correct reference type (i.e. value vs & vs
33
+ * &&) from the argument. In our case, we cannot rely on the argument as
34
+ * supplied by the caller, because that could infer a different reference type
35
+ * than was used in the kernel function. The correct reference type is dictated
36
+ * by the kernel signature and must be identical since we cast function pointers
37
+ * through void* pointers and mismatches would be UB. So we need a forwarding
38
+ * pattern that determines the reference type to use by looking at the
39
+ * explicitly supplied operator signature, not by looking at the argument we're
40
+ * calling it with.
41
+ *
42
+ * What does std::forward do, exactly?
43
+ * ------------------------------------
44
+ * std::forward<T>(t) is a way to cast t to the reference type supplied in T.
45
+ * Let's assume decay_t<T> == U and T is either U or some reference of U.
46
+ * - std::forward<T&>(t) will return U&, no matter what kind of reference t is.
47
+ * - std::forward<T&&>(t) will return U&&, no matter what kind of reference t
48
+ * is.
49
+ * - std::forward<T>(t) will return U&& (not U!), no matter what kind of
50
+ * reference t is.
51
+ *
52
+ * For universal references, that means that in the following function
53
+ * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
54
+ *
55
+ * - when called with arg being a rvalue reference or non-reference value, T
56
+ * gets inferred to be a non-reference U, and std::forward<T>(t) will return
57
+ * U&&, correctly moving the argument.
58
+ * - when called with arg behind a lvalue reference, T gets inferred to be U&
59
+ * because that's the only way to match the signature (in C++, a type that is
60
+ * (T&)&& will collapse to T&). That means std::forward<T>(t) will return U& and
61
+ * the value will not be moved but passed on as a lvalue reference.
62
+ *
63
+ * How do we use that?
64
+ * ------------------------------------
65
+ * But std::forward can also be used outside of the common "universal
66
+ * forwarding" pattern to change reference types. So instead of following the
67
+ * common C++ pattern, we notice what std::forward<T>() actually does, and that
68
+ * is it takes a value and changes its reference to the type of reference passed
69
+ * in as T. If we don't infer T but explicitly specify it, we can use this to
70
+ * forward based on an explicitly specified reference type instead of the
71
+ * inferred argument type.
72
+ *
73
+ * This is why many of the dispatcher functions look like
74
+ * > template<class T> func(T t) { func2<T>(std::forward<T>(t)); }
75
+ * instead of the common
76
+ * > template<class T> func(T&& t) { func2(std::forward<T>(t)); }
77
+ *
78
+ * and are expected to be called by explicitly specifying the template
79
+ * parameters in a way that matches the expected operator signature at each call
80
+ * site.
81
+ */
82
+
83
+ namespace impl {
84
+ // supported_primitive_arg_types defines which primitive types we allow in
85
+ // kernel functions as arguments or returns.
86
+ // Additionally, we support lists, dicts and optionals containing these types.
87
+ using supported_primitive_arg_types = guts::typelist::typelist<
88
+ int64_t,
89
+ double,
90
+ bool,
91
+ std::string_view,
92
+ at::Tensor,
93
+ at::Scalar,
94
+ c10::QScheme,
95
+ c10::ScalarType,
96
+ c10::Device,
97
+ c10::DeviceIndex,
98
+ c10::Layout,
99
+ c10::MemoryFormat,
100
+ at::Dimname>;
101
+
102
+ // We have an unboxed functor in hand that takes C++ arguments, and
103
+ // we're building a boxed functor wrapper for it that takes IValues.
104
+ // So "outside" is boxed and "inside" is unboxed.
105
+ //
106
+ // So a valid input type is one that our boxed functor wrapper can
107
+ // unbox from an IValue into a C++ value.
108
+ //
109
+ // Whereas a valid output type is one that our wrapper can receive
110
+ // as a C++ value from the unboxed functor, and box into an IValue.
111
+
112
+ //
113
+ // assert_is_valid_input_type
114
+ // checks that T can be unboxed from an IValue into a C++ value.
115
+ //
116
+
117
+ template <class T, bool AllowDeprecatedTypes, class Enable = void>
118
+ struct assert_is_valid_input_type {
119
+ assert_is_valid_input_type() {
120
+ if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
121
+ value) {
122
+ /* everything is ok, this is a primitive type */
123
+ } else {
124
+ /* otherwise this must be an instance of a valid custom class, since it
125
+ can only have been created via IValue(x), which ensures this. */
126
+ }
127
+ }
128
+ };
129
+
130
+ template <class T, bool AllowDeprecatedTypes>
131
+ struct assert_is_valid_input_type<std::optional<T>, AllowDeprecatedTypes>
132
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {};
133
+
134
+ template <bool AllowDeprecatedTypes, class... Args>
135
+ struct TypeCheckHelper;
136
+
137
+ template <bool AllowDeprecatedTypes>
138
+ struct TypeCheckHelper<AllowDeprecatedTypes> {};
139
+
140
+ template <bool AllowDeprecatedTypes, class Head, class... Rest>
141
+ struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...>
142
+ : TypeCheckHelper<AllowDeprecatedTypes, Rest...> {
143
+ assert_is_valid_input_type<Head, AllowDeprecatedTypes> check;
144
+ };
145
+
146
+ template <class... Contained, bool AllowDeprecatedTypes>
147
+ struct assert_is_valid_input_type<
148
+ std::tuple<Contained...>,
149
+ AllowDeprecatedTypes>
150
+ : TypeCheckHelper<AllowDeprecatedTypes, Contained...> {};
151
+
152
+ template <class Key, class Value, bool AllowDeprecatedTypes>
153
+ struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes>
154
+ : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
155
+ static_assert(
156
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
157
+ "You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
158
+ };
159
+
160
+ template <class Key, class Value, bool AllowDeprecatedTypes>
161
+ struct assert_is_valid_input_type<
162
+ std::unordered_map<Key, Value>,
163
+ AllowDeprecatedTypes>
164
+ : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
165
+ static_assert(
166
+ AllowDeprecatedTypes,
167
+ "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
168
+ static_assert(
169
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
170
+ "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
171
+ };
172
+
173
+ template <class T, bool AllowDeprecatedTypes>
174
+ struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes>
175
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
176
+ static_assert(
177
+ !std::is_same_v<T, at::Scalar>,
178
+ "You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
179
+ };
180
+
181
+ template <class T, bool AllowDeprecatedTypes>
182
+ struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
183
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
184
+ static_assert(
185
+ !std::is_same_v<T, at::Scalar>,
186
+ "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
187
+ };
188
+
189
+ template <class T, bool AllowDeprecatedTypes>
190
+ struct assert_is_valid_input_type<
191
+ c10::OptionalArrayRef<T>,
192
+ AllowDeprecatedTypes>
193
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
194
+ static_assert(
195
+ !std::is_same_v<T, at::Scalar>,
196
+ "You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
197
+ };
198
+
199
+ template <class T, size_t N, bool AllowDeprecatedTypes>
200
+ struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
201
+ : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
202
+ static_assert(
203
+ !std::is_same_v<T, at::Scalar>,
204
+ "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
205
+ };
206
+
207
+ template <class T, bool AllowDeprecatedTypes>
208
+ struct assert_is_valid_input_type<
209
+ T,
210
+ AllowDeprecatedTypes,
211
+ std::enable_if_t<std::is_same_v<float, T>>> {
212
+ // There is no reason to support float when we have double. Keep the API lean.
213
+ static_assert(
214
+ guts::false_t<T>::value,
215
+ "You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
216
+ };
217
+ template <class T, bool AllowDeprecatedTypes>
218
+ struct assert_is_valid_input_type<
219
+ T,
220
+ AllowDeprecatedTypes,
221
+ std::enable_if_t<std::is_same_v<const char*, T>>> {
222
+ static_assert(
223
+ guts::false_t<T>::value,
224
+ "You tried to register a kernel with an unsupported input type: const char*. Please use std::string_view instead.");
225
+ };
226
+ template <class T, bool AllowDeprecatedTypes>
227
+ struct assert_is_valid_input_type<
228
+ T,
229
+ AllowDeprecatedTypes,
230
+ std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
231
+ static_assert(
232
+ guts::false_t<T>::value,
233
+ "You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead.");
234
+ };
235
+ template <class T, bool AllowDeprecatedTypes>
236
+ struct assert_is_valid_input_type<
237
+ T,
238
+ AllowDeprecatedTypes,
239
+ std::enable_if_t<
240
+ std::is_integral_v<T> &&
241
+ !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
242
+ static_assert(
243
+ guts::false_t<T>::value,
244
+ "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
245
+ };
246
+ template <class T, bool AllowDeprecatedTypes>
247
+ struct assert_is_valid_input_type<
248
+ T,
249
+ AllowDeprecatedTypes,
250
+ std::enable_if_t<std::is_same_v<const c10::SymInt&, T>>> {
251
+ static_assert(
252
+ guts::false_t<T>::value,
253
+ "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
254
+ };
255
+
256
+ // TODO: it probably would be good to tighten this up quite a bit more with
257
+ // an explicit list for everything
258
+
259
+ //
260
+ // assert_is_valid_output_type
261
+ //
262
+
263
+ template <class T, bool AllowDeprecatedTypes, class Enable = void>
264
+ struct assert_is_valid_output_type {
265
+ assert_is_valid_output_type() {
266
+ if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::
267
+ value) {
268
+ /* everything is ok, this is a primitive type */
269
+ } else {
270
+ /* otherwise T is verified to be a registered custom class in the IValue
271
+ constructor, so no benefit in double-checking here */
272
+ }
273
+ }
274
+ };
275
+
276
+ template <class T, bool AllowDeprecatedTypes>
277
+ struct assert_is_valid_output_type<std::optional<T>, AllowDeprecatedTypes>
278
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
279
+
280
+ template <class T, bool AllowDeprecatedTypes>
281
+ struct assert_is_valid_output_type<
282
+ c10::OptionalArrayRef<T>,
283
+ AllowDeprecatedTypes>
284
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
285
+
286
+ template <class Key, class Value, bool AllowDeprecatedTypes>
287
+ struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes>
288
+ : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
289
+ static_assert(
290
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
291
+ "You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
292
+ static_assert(
293
+ !std::is_same_v<Value, at::Scalar>,
294
+ "You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
295
+ };
296
+
297
+ template <class Key, class Value, bool AllowDeprecatedTypes>
298
+ struct assert_is_valid_output_type<
299
+ std::unordered_map<Key, Value>,
300
+ AllowDeprecatedTypes>
301
+ : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
302
+ static_assert(
303
+ AllowDeprecatedTypes,
304
+ "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
305
+ static_assert(
306
+ guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
307
+ "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
308
+ static_assert(
309
+ !std::is_same_v<Value, at::Scalar>,
310
+ "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
311
+ };
312
+
313
+ template <class T, bool AllowDeprecatedTypes>
314
+ struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes>
315
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
316
+ static_assert(
317
+ !std::is_same_v<T, at::Scalar>,
318
+ "You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
319
+ };
320
+
321
+ template <class T, bool AllowDeprecatedTypes>
322
+ struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes>
323
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
324
+ static_assert(
325
+ !std::is_same_v<T, at::Scalar>,
326
+ "You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
327
+ // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel
328
+ // with an unsupported output type: std::vector<T>. Please use List<T>
329
+ // instead.");
330
+ };
331
+
332
+ template <class T, size_t N, bool AllowDeprecatedTypes>
333
+ struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
334
+ : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
335
+ static_assert(
336
+ !std::is_same_v<T, at::Scalar>,
337
+ "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
338
+ };
339
+
340
+ // The following specialisations of assert_is_valid_output_type are technically
341
+ // not necessary since we would hit the base case and show an error message
342
+ // there if they didn't exist, but we can show a better error message
343
+ // in some common error scenarios.
344
+ template <class T, bool AllowDeprecatedTypes>
345
+ struct assert_is_valid_output_type<
346
+ T,
347
+ AllowDeprecatedTypes,
348
+ std::enable_if_t<std::is_same_v<float, T>>> {
349
+ // There is no reason to support float when we have double. Keep the API lean.
350
+ static_assert(
351
+ guts::false_t<T>::value,
352
+ "You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
353
+ };
354
+ template <class T, bool AllowDeprecatedTypes>
355
+ struct assert_is_valid_output_type<
356
+ T,
357
+ AllowDeprecatedTypes,
358
+ std::enable_if_t<std::is_same_v<const char*, T>>> {
359
+ static_assert(
360
+ guts::false_t<T>::value,
361
+ "You tried to register a kernel with an unsupported output type: const char*. Please use std::string_view instead.");
362
+ };
363
+ template <class T, bool AllowDeprecatedTypes>
364
+ struct assert_is_valid_output_type<
365
+ T,
366
+ AllowDeprecatedTypes,
367
+ std::enable_if_t<std::is_same_v<std::vector<bool>, T>>> {
368
+ static_assert(
369
+ guts::false_t<T>::value,
370
+ "You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead.");
371
+ };
372
+ template <class T, bool AllowDeprecatedTypes>
373
+ struct assert_is_valid_output_type<
374
+ T,
375
+ AllowDeprecatedTypes,
376
+ std::enable_if_t<
377
+ std::is_integral_v<T> &&
378
+ !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
379
+ static_assert(
380
+ guts::false_t<T>::value,
381
+ "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
382
+ };
383
+
384
+ // ivalue_to_arg
385
+
386
+ template <class T>
387
+ struct decay_if_not_tensor final {
388
+ using type = std::decay_t<T>;
389
+ };
390
+
391
+ template <>
392
+ struct decay_if_not_tensor<at::Tensor&> final {
393
+ using type = at::Tensor&;
394
+ };
395
+
396
+ template <>
397
+ struct decay_if_not_tensor<const at::Tensor&> final {
398
+ using type = const at::Tensor&;
399
+ };
400
+
401
+ template <class T, bool AllowDeprecatedTypes>
402
+ struct ivalue_to_arg final {
403
+ static decltype(auto) call(IValue& v) {
404
+ assert_is_valid_input_type<T, AllowDeprecatedTypes>();
405
+ return std::move(v).to<T>();
406
+ }
407
+ };
408
+
409
+ // The following two specializations take advantage of specialized
410
+ // `toTensor()` overloads on IValue to avoid copying.
411
+ template <bool AllowDeprecatedTypes>
412
+ struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
413
+ // We cannot use the default implementation if they asked for a
414
+ // `at::Tensor&` because it moves from the IValue, so it can't get
415
+ // an lvalue reference.
416
+ static at::Tensor& call(IValue& v) {
417
+ // Tensor& is valid, don't bother asserting
418
+ return v.toTensor();
419
+ }
420
+ };
421
+
422
+ template <bool AllowDeprecatedTypes>
423
+ struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
424
+ // We should not use the default implementation if they asked for
425
+ // a `const at::Tensor&` because it moves from the IValue and they
426
+ // didn't ask for that.
427
+ static const at::Tensor& call(IValue& v) {
428
+ // const Tensor& is valid, don't bother asserting
429
+ return v.toTensor();
430
+ }
431
+ };
432
+
433
+ template <bool AllowDeprecatedTypes>
434
+ struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final {
435
+ static List<at::Tensor> call(IValue& v) {
436
+ return v.toTensorList();
437
+ }
438
+ };
439
+
440
+ template <class T, bool AllowDeprecatedTypes>
441
+ struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
442
+ // If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and
443
+ // pass that to the operator. std::vector<T> is implicitly convertible to
444
+ // ArrayRef<T>.
445
+ static std::vector<T> call(IValue& v) {
446
+ return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
447
+ }
448
+ };
449
+ template <bool AllowDeprecatedTypes>
450
+ struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
451
+ static std::vector<c10::SymInt> call(IValue& v) {
452
+ if (v.isIntList()) {
453
+ std::vector<c10::SymInt> r;
454
+ auto src = v.toIntList();
455
+ std::transform(
456
+ src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
457
+ return c10::SymInt(i);
458
+ });
459
+ return r;
460
+ } else {
461
+ return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::
462
+ call(v);
463
+ }
464
+ }
465
+ };
466
+ template <bool AllowDeprecatedTypes>
467
+ struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes>
468
+ final {
469
+ static OptionalArray<c10::SymInt> call(IValue& v) {
470
+ if (v.isIntList()) {
471
+ std::vector<c10::SymInt> r;
472
+ auto src = v.toIntList();
473
+ std::transform(
474
+ src.begin(), src.end(), std::back_inserter(r), [](int64_t i) {
475
+ return c10::SymInt(i);
476
+ });
477
+ return OptionalArray<c10::SymInt>(std::move(r));
478
+ } else {
479
+ return std::move(v).to<OptionalArray<c10::SymInt>>();
480
+ }
481
+ }
482
+ };
483
+ template <class T, bool AllowDeprecatedTypes>
484
+ struct ivalue_to_arg<std::optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
485
+ // If an argument is std::optional<ArrayRef<T>>, convert the IValue to an
486
+ // std::optional<std::vector<T>> and pass that to the operator.
487
+ // OptionalArray<T> is basically a std::optional<std::vector<T>> but
488
+ // implicitly convertible to std::optional<ArrayRef<T>>.
489
+ static OptionalArray<T> call(IValue& v) {
490
+ return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
491
+ }
492
+ };
493
+
494
+ template <class T, bool AllowDeprecatedTypes>
495
+ struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final {
496
+ // If an argument is OptionalArrayRef<T>, convert the IValue to an
497
+ // std::optional<std::vector<T>> and pass that to the operator.
498
+ // OptionalArray<T> is basically a std::optional<std::vector<T>> but
499
+ // implicitly convertible to OptionalArrayRef<T>
500
+ static OptionalArray<T> call(IValue& v) {
501
+ return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
502
+ }
503
+ };
504
+
505
+ // return_to_ivalue
506
+ template <class T, bool AllowDeprecatedTypes, class Enable = void>
507
+ struct return_to_ivalue final {};
508
+
509
+ template <class T, bool AllowDeprecatedTypes>
510
+ struct return_to_ivalue<
511
+ T,
512
+ AllowDeprecatedTypes,
513
+ std::enable_if_t<!std::is_same_v<at::Tensor&, T>>>
514
+ final {
515
+ static IValue call(T&& v) {
516
+ assert_is_valid_output_type<T, AllowDeprecatedTypes>();
517
+ return c10::ivalue::from(std::move(v));
518
+ }
519
+ static IValue copy(const T& v) {
520
+ assert_is_valid_output_type<T, AllowDeprecatedTypes>();
521
+ return IValue(v);
522
+ }
523
+ };
524
+
525
+ // Special case to allow kernels to return `Tensor&`.
526
+ // TODO Delete this once kernels don't do that anymore
527
+ template <bool AllowDeprecatedTypes>
528
+ struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final {
529
+ static IValue call(at::Tensor& v) {
530
+ return c10::ivalue::from(v);
531
+ }
532
+ static IValue copy(at::Tensor& v) {
533
+ return IValue(v);
534
+ }
535
+ };
536
+
537
+ // wrap_kernel_functor_unboxed_
538
+
539
+ template <class KernelFunctor, class OpSignature>
540
+ struct wrap_kernel_functor_unboxed_ final {};
541
+
542
+ // This specialization is for kernels with a first argument that is NOT of type
543
+ // DispatchKeySet This includes kernels with 0 arguments.
544
+ template <class KernelFunctor, class ReturnType, class... ParameterTypes>
545
+ struct wrap_kernel_functor_unboxed_<
546
+ KernelFunctor,
547
+ ReturnType(ParameterTypes...)>
548
+ final {
549
+ static_assert(
550
+ std::is_same_v<
551
+ ReturnType,
552
+ typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
553
+ "Return type mismatch");
554
+ static_assert(
555
+ std::is_same_v<
556
+ guts::typelist::typelist<ParameterTypes...>,
557
+ typename guts::infer_function_traits_t<
558
+ KernelFunctor>::parameter_types>,
559
+ "Parameter types mismatch");
560
+
561
+ // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
562
+ // doesn't use &&
563
+ static ReturnType call(
564
+ OperatorKernel* functor,
565
+ DispatchKeySet /*unused*/,
566
+ ParameterTypes... args) {
567
+ KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
568
+ // Note [Plumbing Keys Through The Dispatcher 2]
569
+ // See Note [Plumbing Keys Through The Dispatcher] for the background.
570
+ // This functor explicitly takes in a dispatchKeySet and drops it on the
571
+ // floor- it does not forward it to the registered kernel.
572
+ //
573
+ // This is due to the calling convention within the dispatcher, which
574
+ // expects all registered kernels to have a first argument of type
575
+ // DispatchKeySet.
576
+ // This is not the case for pretty much all manually written kernels,
577
+ // however- this functor serves to separate the calling convention of the
578
+ // dispatcher from the calling convention of manually written kernels.
579
+ return (*functor_)(std::forward<ParameterTypes>(args)...);
580
+ }
581
+ };
582
+
583
+ // This specialization is for kernels with a first argument of type
584
+ // DispatchKeySet
585
+ template <class KernelFunctor, class ReturnType, class... ParameterTypes>
586
+ struct wrap_kernel_functor_unboxed_<
587
+ KernelFunctor,
588
+ ReturnType(DispatchKeySet, ParameterTypes...)>
589
+ final {
590
+ static_assert(
591
+ std::is_same_v<
592
+ ReturnType,
593
+ typename guts::infer_function_traits_t<KernelFunctor>::return_type>,
594
+ "Return type mismatch");
595
+ static_assert(
596
+ std::is_same_v<
597
+ guts::typelist::typelist<DispatchKeySet, ParameterTypes...>,
598
+ typename guts::infer_function_traits_t<
599
+ KernelFunctor>::parameter_types>,
600
+ "Parameter types mismatch");
601
+
602
+ // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes
603
+ // doesn't use &&
604
+ static ReturnType call(
605
+ OperatorKernel* functor,
606
+ DispatchKeySet dispatchKeySet,
607
+ ParameterTypes... args) {
608
+ KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
609
+ // We're explicitly taking in a dispatchKeySet and forwarding it to the
610
+ // registered kernel. See Note [Plumbing Keys Through The Dispatcher 2] for
611
+ // details.
612
+ return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
613
+ }
614
+ };
615
+
616
+ template <class KernelFunctor>
617
+ using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<
618
+ KernelFunctor,
619
+ typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
620
+
621
+ // call_functor_with_args_from_stack
622
+
623
+ template <
624
+ class Functor,
625
+ bool AllowDeprecatedTypes,
626
+ size_t... ivalue_arg_indices,
627
+ typename... ArgTypes>
628
+ std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
629
+ call_functor_with_args_from_stack_(
630
+ OperatorKernel* functor,
631
+ DispatchKeySet dispatchKeySet,
632
+ Stack* stack,
633
+ std::index_sequence<ivalue_arg_indices...> /*unused*/,
634
+ guts::typelist::typelist<ArgTypes...>* /*unused*/) {
635
+ (void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would
636
+ // be unused and we have to silence the compiler warning.
637
+
638
+ // We're explicitly filtering out DispatchKeySet from the argument list.
639
+ // Some kernels take a DispatchKeySet as their first argument in order to
640
+ // plumb keys through the dispatcher. We don't want to expose the
641
+ // DispatchKeySet type to jit, so we don't include this argument on the stack.
642
+ // See Note [Plumbing Keys Through The Dispatcher] for the background.
643
+ return wrap_kernel_functor_unboxed<Functor>::call(
644
+ functor,
645
+ dispatchKeySet,
646
+ ivalue_to_arg<
647
+ typename decay_if_not_tensor<ArgTypes>::type,
648
+ AllowDeprecatedTypes>::
649
+ call(torch::jit::peek(
650
+ *stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))...);
651
+ }
652
+
653
+ template <class Functor, bool AllowDeprecatedTypes>
654
+ std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
655
+ call_functor_with_args_from_stack(
656
+ OperatorKernel* functor,
657
+ DispatchKeySet dispatchKeySet,
658
+ Stack* stack) {
659
+ // We're explicitly filtering out DispatchKeySet from the argument list.
660
+ // Some kernels take a DispatchKeySet as their first argument in order to
661
+ // plumb keys through the dispatcher. We don't want to expose the
662
+ // DispatchKeySet type to jit, so we don't include this argument on the stack.
663
+ // See Note [Plumbing Keys Through The Dispatcher] for the background.
664
+ using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
665
+ Functor>::parameter_types;
666
+ constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value;
667
+ return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(
668
+ functor,
669
+ dispatchKeySet,
670
+ stack,
671
+ std::make_index_sequence<num_ivalue_args>(),
672
+ static_cast<ArgTypes*>(nullptr));
673
+ }
674
+
675
+ // push_outputs
676
+
677
+ template <class OutputType, bool AllowDeprecatedTypes>
678
+ struct push_outputs final {
679
+ // Contrary to [Note: Argument forwarding in the dispatcher], we use
680
+ // OutputType&& here to avoid one extra call to the move constructor in this
681
+ // case. This is still not a universal reference though because OutputType is
682
+ // an explicitly specified class template parameter.
683
+ static void call(OutputType&& output, Stack* stack) {
684
+ torch::jit::push(
685
+ *stack,
686
+ return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(
687
+ std::forward<OutputType>(output)));
688
+ }
689
+ static void copy(const OutputType& output, Stack* stack) {
690
+ torch::jit::push(
691
+ *stack,
692
+ return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
693
+ }
694
+ };
695
+ template <class... OutputTypes, bool AllowDeprecatedTypes>
696
+ struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
697
+ static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
698
+ call_(
699
+ std::move(output),
700
+ stack,
701
+ std::make_index_sequence<sizeof...(OutputTypes)>());
702
+ }
703
+ static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
704
+ copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
705
+ }
706
+
707
+ private:
708
+ template <size_t... indices>
709
+ static void call_(
710
+ std::tuple<OutputTypes...>&& output,
711
+ Stack* stack,
712
+ std::index_sequence<indices...> /*unused*/) {
713
+ torch::jit::push(
714
+ *stack,
715
+ return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(
716
+ std::forward<OutputTypes>(std::get<indices>(output)))...);
717
+ }
718
+ template <size_t... indices>
719
+ static void copy_(
720
+ const std::tuple<OutputTypes...>& output,
721
+ Stack* stack,
722
+ std::index_sequence<indices...> /*unused*/) {
723
+ torch::jit::push(
724
+ *stack,
725
+ return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(
726
+ std::get<indices>(output))...);
727
+ }
728
+ };
729
+ template <bool AllowDeprecatedTypes>
730
+ struct push_outputs<void, AllowDeprecatedTypes> final {
731
+ static void call(int /*dummy*/, Stack* /*stack*/) {}
732
+ static void copy(int /*dummy*/, Stack* /*stack*/) {}
733
+ };
734
+
735
+ // make_boxed_from_unboxed_functor
736
+
737
+ template <class KernelFunctor, bool AllowDeprecatedTypes>
738
+ struct make_boxed_from_unboxed_functor final {
739
+ static_assert(
740
+ std::is_base_of_v<OperatorKernel, KernelFunctor>,
741
+ "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
742
+
743
+ static void call(
744
+ OperatorKernel* functor,
745
+ const OperatorHandle& /*unused*/,
746
+ DispatchKeySet dispatchKeySet,
747
+ Stack* stack) {
748
+ using ReturnType =
749
+ typename guts::infer_function_traits_t<KernelFunctor>::return_type;
750
+ // We're explicitly filtering out DispatchKeySet from the argument list.
751
+ // Some kernels take a DispatchKeySet as their first argument in order to
752
+ // plumb keys through the dispatcher. We don't want to expose the
753
+ // DispatchKeySet type to jit, so we don't include this argument on the
754
+ // stack. See Note [Plumbing Keys Through The Dispatcher] for the
755
+ // background.
756
+ using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<
757
+ KernelFunctor>::parameter_types;
758
+ constexpr bool has_outputs = !std::is_same_v<void, ReturnType>;
759
+ constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value;
760
+ if constexpr (has_outputs) {
761
+ // Decay ReturnType to ReturnType_ so that if a reference gets returned,
762
+ // we actually store it by value and don't get a dangling reference. This
763
+ // is only required because some kernels still return `Tensor&`. [Note:
764
+ // VC++ and 'std': ambiguous symbol]
765
+ using ReturnType_ = ::std::decay_t<ReturnType>;
766
+ ReturnType_ output = call_functor_with_args_from_stack<
767
+ KernelFunctor,
768
+ AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
769
+ torch::jit::drop(*stack, num_inputs);
770
+ // See note [ VC++ and 'std': ambiguous symbol]
771
+ push_outputs<ReturnType_, AllowDeprecatedTypes>::call(
772
+ ::std::move(output), stack);
773
+ } else {
774
+ call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(
775
+ functor, dispatchKeySet, stack);
776
+ torch::jit::drop(*stack, num_inputs);
777
+ }
778
+ }
779
+ };
780
+ } // namespace impl
781
+
782
+ } // namespace c10
783
+
784
+ namespace torch {
785
+ using OperatorKernel = c10::OperatorKernel;
786
+ }
787
+
788
+ #else
789
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
790
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <gmock/gmock.h>
5
+ #include <gtest/gtest.h>
6
+
7
+ #include <ATen/core/Tensor.h>
8
+ #include <ATen/core/dispatch/Dispatcher.h>
9
+ #include <ATen/core/ivalue.h>
10
+ #include <c10/core/CPUAllocator.h>
11
+ #include <c10/util/irange.h>
12
+
13
+ template <class... Inputs>
14
+ inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
15
+ return {std::forward<Inputs>(inputs)...};
16
+ }
17
+
18
+ inline at::Tensor dummyTensor(
19
+ c10::DispatchKeySet ks,
20
+ bool requires_grad = false) {
21
+ auto* allocator = c10::GetCPUAllocator();
22
+ int64_t nelements = 1;
23
+ auto dtype = caffe2::TypeMeta::Make<float>();
24
+ int64_t size_bytes = nelements * dtype.itemsize();
25
+ auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
26
+ c10::StorageImpl::use_byte_size_t(),
27
+ size_bytes,
28
+ allocator->allocate(size_bytes),
29
+ allocator,
30
+ /*resizable=*/true);
31
+ at::Tensor t =
32
+ at::detail::make_tensor<c10::TensorImpl>(storage_impl, ks, dtype);
33
+ // TODO: We add this to simulate the ideal case where we only have Autograd
34
+ // backend keys
35
+ // on Tensor when it requires grad. But currently Autograd keys are
36
+ // added in TensorImpl constructor by default.
37
+ if (!requires_grad) {
38
+ t.unsafeGetTensorImpl()->remove_autograd_key();
39
+ }
40
+ return t;
41
+ }
42
+
43
+ inline at::Tensor dummyTensor(
44
+ c10::DispatchKey dispatch_key,
45
+ bool requires_grad = false) {
46
+ return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad);
47
+ }
48
+
49
+ template <class... Args>
50
+ inline std::vector<c10::IValue> callOp(
51
+ const c10::OperatorHandle& op,
52
+ Args... args) {
53
+ auto stack = makeStack(std::forward<Args>(args)...);
54
+ op.callBoxed(&stack);
55
+ return stack;
56
+ }
57
+
58
+ template <class Result, class... Args>
59
+ inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
60
+ return op.typed<Result(Args...)>().call(std::forward<Args>(args)...);
61
+ }
62
+
63
+ template <class Result, class... Args>
64
+ inline Result callOpUnboxedWithDispatchKey(
65
+ const c10::OperatorHandle& op,
66
+ c10::DispatchKey dispatchKey,
67
+ Args... args) {
68
+ return op.typed<Result(Args...)>().callWithDispatchKey(
69
+ dispatchKey, std::forward<Args>(args)...);
70
+ }
71
+
72
+ template <class Result, class... Args>
73
+ inline Result callOpUnboxedWithPrecomputedDispatchKeySet(
74
+ const c10::OperatorHandle& op,
75
+ c10::DispatchKeySet ks,
76
+ Args... args) {
77
+ return op.typed<Result(Args...)>().redispatch(
78
+ ks, std::forward<Args>(args)...);
79
+ }
80
+
81
+ inline void expectDoesntFindKernel(
82
+ const char* op_name,
83
+ c10::DispatchKey dispatch_key) {
84
+ auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
85
+ EXPECT_ANY_THROW(callOp(*op, dummyTensor(dispatch_key), 5););
86
+ }
87
+
88
+ inline void expectDoesntFindOperator(const char* op_name) {
89
+ auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
90
+ EXPECT_FALSE(op.has_value());
91
+ }
92
+
93
+ template <class Exception, class Functor>
94
+ inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
95
+ try {
96
+ std::forward<Functor>(functor)();
97
+ } catch (const Exception& e) {
98
+ EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
99
+ return;
100
+ }
101
+ ADD_FAILURE() << "Expected to throw exception containing \""
102
+ << expectMessageContains << "\" but didn't throw";
103
+ }
104
+
105
+ template <class T, size_t N>
106
+ void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
107
+ EXPECT_EQ(expected.size(), actual.size());
108
+ for (const auto i : c10::irange(expected.size())) {
109
+ EXPECT_EQ(expected[i], actual[i]);
110
+ }
111
+ }
112
+
113
+ template <class T>
114
+ void expectListEquals(c10::ArrayRef<T> expected, c10::ArrayRef<T> actual) {
115
+ EXPECT_EQ(expected.size(), actual.size());
116
+ for (const auto i : c10::irange(expected.size())) {
117
+ EXPECT_EQ(expected[i], actual[i]);
118
+ }
119
+ }
120
+
121
+ template <class T>
122
+ void expectListEquals(c10::ArrayRef<T> expected, c10::List<T> actual) {
123
+ EXPECT_EQ(expected.size(), actual.size());
124
+ for (const auto i : c10::irange(expected.size())) {
125
+ EXPECT_EQ(expected[i], actual.get(i));
126
+ }
127
+ }
128
+
129
+ template <class T>
130
+ void expectListEquals(c10::ArrayRef<T> expected, std::vector<T> actual) {
131
+ EXPECT_EQ(expected.size(), actual.size());
132
+ for (const auto i : c10::irange(expected.size())) {
133
+ EXPECT_EQ(expected[i], actual[i]);
134
+ }
135
+ }
136
+
137
+ // NB: This is not really sound, but all of the type sets constructed here
138
+ // are singletons so it's fine
139
+ static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
140
+ return legacyExtractDispatchKey(t.key_set());
141
+ }
142
+
143
+ #else
144
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
145
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/CppSignature.h ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <c10/core/DispatchKeySet.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/util/Metaprogramming.h>
7
+ #include <c10/util/Type.h>
8
+ #include <typeindex>
9
+
10
+ namespace c10::impl {
11
+
12
+ // A CppSignature object holds RTTI information about a C++ function signature
13
+ // at runtime and can compare them or get a debug-printable name.
14
+ class TORCH_API CppSignature final {
15
+ public:
16
+ CppSignature(const CppSignature&) = default;
17
+ CppSignature(CppSignature&&) noexcept = default;
18
+ CppSignature& operator=(const CppSignature&) = default;
19
+ CppSignature& operator=(CppSignature&&) noexcept = default;
20
+
21
+ template <class FuncType>
22
+ static CppSignature make() {
23
+ // Normalize functors, lambdas, function pointers, etc. into the plain
24
+ // function type The first argument of the schema might be of type
25
+ // DispatchKeySet, in which case we remove it. We do this to guarantee that
26
+ // all CppSignature's for an operator will match, even if they're registered
27
+ // with different calling conventions.
28
+ // See Note [Plumbing Keys Through The Dispatcher]
29
+ using decayed_function_type =
30
+ typename c10::remove_DispatchKeySet_arg_from_func<
31
+ std::decay_t<FuncType>>::func_type;
32
+
33
+ return CppSignature(std::type_index(typeid(decayed_function_type)));
34
+ }
35
+
36
+ std::string name() const {
37
+ return c10::demangle(signature_.name());
38
+ }
39
+
40
+ friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) {
41
+ if (lhs.signature_ == rhs.signature_) {
42
+ return true;
43
+ }
44
+ // Without RTLD_GLOBAL, the type_index comparison could yield false because
45
+ // they point to different instances of the RTTI data, but the types would
46
+ // still be the same. Let's check for that case too.
47
+ // Note that there still is a case where this might not work, i.e. when
48
+ // linking libraries of different compilers together, they might have
49
+ // different ways to serialize a type name. That, together with a missing
50
+ // RTLD_GLOBAL, would still fail this.
51
+ if (0 == strcmp(lhs.signature_.name(), rhs.signature_.name())) {
52
+ return true;
53
+ }
54
+
55
+ return false;
56
+ }
57
+
58
+ private:
59
+ explicit CppSignature(std::type_index signature)
60
+ : signature_(std::move(signature)) {}
61
+ std::type_index signature_;
62
+ };
63
+
64
+ inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) {
65
+ return !(lhs == rhs);
66
+ }
67
+
68
+ } // namespace c10::impl
69
+
70
+ #else
71
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
72
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Variadic.h>
5
+ #include <ATen/core/function_schema.h>
6
+ #include <ATen/core/jit_type.h>
7
+ #include <ATen/core/stack.h>
8
+ #include <c10/core/DispatchKeySet.h>
9
+ #include <c10/util/Bitset.h>
10
+ #include <c10/util/irange.h>
11
+ #include <cstdint>
12
+
13
+ namespace c10 {
14
+
15
+ namespace impl {
16
+
17
+ // Take a DispatchKeySet for a Tensor and determine what the actual dispatch
18
+ // DispatchKey should be, taking into account TLS, and skipping backends which
19
+ // fall through.
20
+ //
21
+ // Unlike Tensor::key_set(), the value of this on a tensor can change depending
22
+ // on TLS.
23
+ //
24
+ // NB: If there is no valid dispatch key, this will return Undefined
25
+ inline DispatchKeySet computeDispatchKeySet(
26
+ DispatchKeySet ks,
27
+ // The key mask lets us eliminate (by zero entries) keys which should not
28
+ // be considered for dispatch. There are two cases when we use this:
29
+ //
30
+ // - If an operator's dispatch table contains a fallthrough entry, we
31
+ // should bypass it entirely when finding the key
32
+ // - If a user invokes with redispatch, the mask lets us
33
+ // zero out the key the user asked us to stop.
34
+ //
35
+ // These excluded backends are NOT tracked in the TLS, but must be applied
36
+ // AFTER TLS (since the backend may have been introduced for consideration
37
+ // by the included TLS), which is why you have to pass them in to this
38
+ // function (as opposed to just applying it to the input 'ks').
39
+ DispatchKeySet key_mask) {
40
+ c10::impl::LocalDispatchKeySet local =
41
+ c10::impl::tls_local_dispatch_key_set();
42
+ // TODO: It's a bit irritating that we have to do logical ORs here, it would
43
+ // be nice to only do one. Can always_included be folded into the TLS? Well,
44
+ // it's a bit troublesome, because fastpath TLS access requires the type of
45
+ // the TLS in question to be zero-initialized, so you don't actually win
46
+ // anything in that case.
47
+ return (((ks | local.included_) - local.excluded_) & key_mask);
48
+ }
49
+
50
+ } // namespace impl
51
+
52
+ namespace detail {
53
+ // A small gadget to extract the DispatchKeySet from types which are known
54
+ // to have it. Used to extract dispatch keys from unboxed calls.
55
+ struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
56
+ DispatchKeySet ts;
57
+ void operator()(const at::Tensor& x) {
58
+ ts = ts | x.key_set();
59
+ }
60
+ void operator()(const std::optional<at::Tensor>& x) {
61
+ if (x.has_value()) {
62
+ ts = ts | x->key_set();
63
+ }
64
+ }
65
+ void operator()(at::ArrayRef<at::Tensor> xs) {
66
+ for (const auto& x : xs) {
67
+ ts = ts | x.key_set();
68
+ }
69
+ }
70
+ // Tensor?[] translates to this case.
71
+ void operator()(const c10::List<std::optional<at::Tensor>>& xs) {
72
+ for (std::optional<at::Tensor> x : xs) {
73
+ if (x.has_value()) {
74
+ ts = ts | x.value().key_set();
75
+ }
76
+ }
77
+ }
78
+ // Structured Tensor[] translates to this case
79
+ void operator()(const at::ITensorListRef& xs) {
80
+ for (const auto& x : xs) {
81
+ ts = ts | x.key_set();
82
+ }
83
+ }
84
+ [[noreturn]] void operator()(
85
+ at::ArrayRef<std::optional<at::Tensor>> /*unused*/) {
86
+ // Just checking that the handling of Tensor?[] didn't change.
87
+ TORCH_INTERNAL_ASSERT(false);
88
+ }
89
+ void operator()(const at::Generator& gen) {
90
+ if (gen.defined()) {
91
+ ts = ts | gen.key_set();
92
+ }
93
+ }
94
+ void operator()(const std::optional<at::Generator>& gen) {
95
+ if (gen.has_value() && gen->defined()) {
96
+ ts = ts | gen->key_set();
97
+ }
98
+ }
99
+ template <typename T>
100
+ void operator()(const T& /*unused*/) {
101
+ // do nothing
102
+ }
103
+ };
104
+
105
+ // NB: take by const reference (Don't do universal forwarding here! You
106
+ // don't want to move into this function!)
107
+ template <typename... Args>
108
+ DispatchKeySet multi_dispatch_key_set(const Args&... args) {
109
+ return MultiDispatchKeySet().apply(args...).ts;
110
+ }
111
+ } // namespace detail
112
+
113
+ /**
114
+ * An instance of DispatchKeyExtractor knows how to get a dispatch key given
115
+ * a list of arguments for an operator call.
116
+ *
117
+ * The instance is specific for a certain operator as:
118
+ * - In boxed dispatch, different operators have different ways to extract
119
+ * the dispatch key (e.g. different numbers of arguments), and we precompute
120
+ * the stack locations we should look at; and
121
+ * - In all dispatch, some backends should be excluded from dispatch because
122
+ * they have been registered as fallthrough. The set of excluded backends
123
+ * varies from operator, as some operators may have overridden the
124
+ * fallthrough with custom behavior.
125
+ *
126
+ * Note - this should maintain identical impl to the py dispatcher key
127
+ * extraction logic at pytorch/torch/dispatcher.py
128
+ */
129
+ struct TORCH_API DispatchKeyExtractor final {
130
+ public:
131
+ static DispatchKeyExtractor make(const FunctionSchema& schema) {
132
+ return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
133
+ }
134
+
135
+ static DispatchKeyExtractor makeUninitialized() {
136
+ return DispatchKeyExtractor(c10::utils::bitset());
137
+ }
138
+
139
+ void registerSchema(const FunctionSchema& schema) {
140
+ TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
141
+ dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
142
+ }
143
+ void deregisterSchema() {
144
+ dispatch_arg_indices_reverse_ = c10::utils::bitset();
145
+ }
146
+
147
+ DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
148
+ DispatchKeySet ks;
149
+ dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t
150
+ reverse_arg_index) {
151
+ const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
152
+ if (C10_LIKELY(ivalue.isTensor())) {
153
+ // NB: Take care not to introduce a refcount bump (there's
154
+ // no safe toTensorRef method, alas)
155
+ ks = ks | ivalue.unsafeToTensorImpl()->key_set();
156
+ } else if (C10_UNLIKELY(ivalue.isTensorList())) {
157
+ // NB: use toListRef as it doesn't induce refcount bumps
158
+ // (toTensorListRef is not a thing)
159
+ for (const auto& nv : ivalue.toListRef()) {
160
+ auto* tensor = nv.unsafeToTensorImpl();
161
+ ks = ks | tensor->key_set();
162
+ }
163
+ }
164
+ // Tensor?[] translates to a c10::List<IValue> so we need to peek inside
165
+ else if (C10_UNLIKELY(ivalue.isList())) {
166
+ for (const auto& elt : ivalue.toListRef()) {
167
+ if (elt.isTensor()) {
168
+ ks = ks | elt.toTensor().key_set();
169
+ }
170
+ }
171
+ }
172
+ });
173
+ // Keys that are fallthrough should be skipped
174
+ if (requiresBitsetPerBackend_) {
175
+ c10::impl::LocalDispatchKeySet tls =
176
+ c10::impl::tls_local_dispatch_key_set();
177
+ auto backend_idx =
178
+ ((ks | tls.included_) - tls.excluded_).getBackendIndex();
179
+ return impl::computeDispatchKeySet(
180
+ ks, nonFallthroughKeysPerBackend_[backend_idx]);
181
+ } else {
182
+ return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
183
+ }
184
+ }
185
+
186
+ template <class... Args>
187
+ DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
188
+ auto ks = detail::multi_dispatch_key_set(args...);
189
+ // Keys that are fallthrough should be skipped
190
+ if (requiresBitsetPerBackend_) {
191
+ c10::impl::LocalDispatchKeySet tls =
192
+ c10::impl::tls_local_dispatch_key_set();
193
+ auto backend_idx =
194
+ ((ks | tls.included_) - tls.excluded_).getBackendIndex();
195
+ return impl::computeDispatchKeySet(
196
+ ks, nonFallthroughKeysPerBackend_[backend_idx]);
197
+ } else {
198
+ return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
199
+ }
200
+ }
201
+
202
+ void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
203
+
204
+ std::string dumpState() const;
205
+ void checkInvariants(const FunctionSchema& schema) const;
206
+
207
+ private:
208
+ static bool isDispatchType(const Type& type) {
209
+ // Checking isSubtypeOf on a DynamicType heap-allocates a
210
+ // DynamicType version of the argument if it's not a DynamicType
211
+ // already, and this has measurable overhead during startup.
212
+ #ifdef C10_MOBILE
213
+ struct CachedTypes {
214
+ DynamicTypePtr listOfTensors;
215
+ DynamicTypePtr listOfOptionalTensors;
216
+ DynamicTypePtr optionalOfTensor;
217
+ };
218
+ static const CachedTypes ct = {
219
+ DynamicType::create(*ListType::ofTensors()),
220
+ DynamicType::create(*ListType::ofOptionalTensors()),
221
+ DynamicType::create(*OptionalType::ofTensor())};
222
+ return type.isSubtypeOf(c10::TypeFactory::get<TensorType>()) ||
223
+ type.isSubtypeOf(ct.listOfTensors) ||
224
+ type.isSubtypeOf(ct.listOfOptionalTensors) ||
225
+ type.isSubtypeOf(ct.optionalOfTensor);
226
+ #else // C10_MOBILE
227
+ return type.isSubtypeOf(*TensorType::get()) ||
228
+ type.isSubtypeOf(*ListType::ofTensors()) ||
229
+ type.isSubtypeOf(*ListType::ofOptionalTensors()) ||
230
+ type.isSubtypeOf(*OptionalType::ofTensor());
231
+ #endif // C10_MOBILE
232
+ }
233
+ static c10::utils::bitset makeBitsetForDispatchArgs(
234
+ const FunctionSchema& schema) {
235
+ TORCH_CHECK(
236
+ schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
237
+ "The function schema has ",
238
+ schema.arguments().size(),
239
+ " arguments but this PyTorch build only supports ",
240
+ c10::utils::bitset::NUM_BITS());
241
+ c10::utils::bitset dispatch_arg_indices_reverse;
242
+ for (const auto index : c10::irange(schema.arguments().size())) {
243
+ if (isDispatchType(*schema.arguments()[index].type())) {
244
+ dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
245
+ }
246
+ }
247
+ return dispatch_arg_indices_reverse;
248
+ }
249
+
250
+ explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
251
+ : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse),
252
+ nonFallthroughKeys_(DispatchKeySet::FULL) {
253
+ for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
254
+ nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
255
+ }
256
+ }
257
+
258
+ // this is a bitset that has ones for each argument index which has to be
259
+ // considered for dispatch. This avoids having to iterate over the stack
260
+ // to find all the tensors. The bits are stored in reverse order, i.e.
261
+ // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
262
+ // the top of the stack (i.e. the i-th last argument of the function)
263
+ // is relevant for dispatch.
264
+ // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just
265
+ // means you must do the fallthrough
266
+ c10::utils::bitset dispatch_arg_indices_reverse_;
267
+
268
+ // Set of functionality keys for which the operator does NOT have fallthrough
269
+ // kernel.
270
+ DispatchKeySet nonFallthroughKeys_;
271
+ // Set of functionality keys for which the operator does NOT have fallthrough
272
+ // kernel, defined PER BACKEND. This is only needed if we know that the
273
+ // operator has a different set of fallthroughs defined for some backends.
274
+ std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
275
+ // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast
276
+ // path), or if we need to fall back to the slower path and check
277
+ // nonFallthroughKeysPerBackend_
278
+ bool requiresBitsetPerBackend_{false};
279
+ };
280
+
281
+ } // namespace c10
282
+
283
+ #else
284
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
285
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/SequenceNumber.h>
5
+ #include <ATen/core/boxing/KernelFunction.h>
6
+ #include <ATen/core/boxing/impl/boxing.h>
7
+ #include <ATen/core/dispatch/CppSignature.h>
8
+ #include <ATen/core/dispatch/OperatorEntry.h>
9
+ #include <ATen/core/dispatch/RegistrationHandleRAII.h>
10
+ #include <ATen/record_function.h>
11
+ #include <c10/core/SafePyObject.h>
12
+ #include <c10/util/Exception.h>
13
+ #include <c10/util/LeftRight.h>
14
+ #include <condition_variable>
15
+ #include <list>
16
+ #include <mutex>
17
+ #include <type_traits>
18
+
19
+ #include <ATen/core/enum_tag.h>
20
+ #include <ATen/core/grad_mode.h>
21
+
22
+ #ifndef NDEBUG
23
+ #include <iostream>
24
+ #endif
25
+
26
+ namespace c10 {
27
+
28
+ TORCH_API bool show_dispatch_trace();
29
+ TORCH_API void dispatch_trace_nesting_incr();
30
+ TORCH_API void dispatch_trace_nesting_decr();
31
+ TORCH_API int64_t dispatch_trace_nesting_value();
32
+
33
+ struct DispatchTraceNestingGuard {
34
+ DispatchTraceNestingGuard() {
35
+ dispatch_trace_nesting_incr();
36
+ }
37
+ ~DispatchTraceNestingGuard() {
38
+ dispatch_trace_nesting_decr();
39
+ }
40
+ };
41
+
42
+ class TORCH_API OperatorHandle;
43
+ template <class FuncType>
44
+ class TypedOperatorHandle;
45
+
46
+ /**
47
+ * Implement this interface and register your instance with the dispatcher
48
+ * to get notified when operators are registered or deregistered with
49
+ * the dispatcher.
50
+ *
51
+ * NB: registration events only occur when a 'def' occurs; we don't trigger
52
+ * on 'impl' or 'fallback' calls.
53
+ */
54
+ class TORCH_API OpRegistrationListener {
55
+ public:
56
+ virtual ~OpRegistrationListener();
57
+
58
+ virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
59
+ virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
60
+ };
61
+
62
+ namespace detail {
63
+ class RegistrationListenerList;
64
+ }
65
+ class SchemaRegistrationHandleRAII;
66
+
67
+ /**
68
+ * Top-level dispatch interface for dispatching via the dynamic dispatcher.
69
+ * Most end users shouldn't use this directly; if you're trying to register
70
+ * ops look in op_registration
71
+ */
72
+ class TORCH_API Dispatcher final {
73
+ private:
74
+ // For direct access to backend fallback information
75
+ friend class impl::OperatorEntry;
76
+
77
+ struct OperatorDef final {
78
+ explicit OperatorDef(OperatorName&& op_name) : op(std::move(op_name)) {}
79
+
80
+ impl::OperatorEntry op;
81
+
82
+ // These refer to the number of outstanding RegistrationHandleRAII
83
+ // for this operator. def_count reflects only def() registrations
84
+ // (in the new world, this should only ever be 1, but old style
85
+ // registrations may register the schema multiple times, which
86
+ // will increase this count). def_and_impl_count reflects the number
87
+ // of combined def() and impl() registrations. When the last def() gets
88
+ // unregistered, we must immediately call the Deregistered listeners, but we
89
+ // must not actually delete the handle as there are other outstanding RAII
90
+ // destructors which will try to destruct and they had better still have a
91
+ // working operator handle in this case
92
+ size_t def_count = 0;
93
+ size_t def_and_impl_count = 0;
94
+ };
95
+ friend class OperatorHandle;
96
+ template <class>
97
+ friend class TypedOperatorHandle;
98
+
99
+ struct Guard final {
100
+ Guard() : alive(true) {}
101
+ std::atomic<bool> alive;
102
+ std::mutex mutex;
103
+ };
104
+
105
+ public:
106
+ ~Dispatcher();
107
+
108
+ // Implementation note: this class abstracts over the fact that we have
109
+ // per-operator dispatch tables. This could be easily adjusted to have a
110
+ // single global hash table.
111
+ static Dispatcher& realSingleton();
112
+
113
+ C10_ALWAYS_INLINE static Dispatcher& singleton() {
114
+ #if !defined C10_MOBILE
115
+ // Implemented inline so that steady-state code needn't incur
116
+ // function-call overhead. We can't just inline `realSingleton`
117
+ // because the function-local static would get duplicated across
118
+ // all DSOs that include & use this header, leading to multiple
119
+ // singleton instances.
120
+ static Dispatcher& s = realSingleton();
121
+ return s;
122
+ #else
123
+ // For C10_MOBILE, we should never inline a static function that
124
+ // has a static member, since the generated code calls
125
+ // __cxa_guard_acquire and __cxa_guard_release which help
126
+ // implement exactly once semantics for the initialization of the
127
+ // static Dispatcher& s above (for the non-mobile case). That
128
+ // additional code when duplicated across all operator stubs
129
+ // for every backend results in a lot of additional code
130
+ // being generated by the compiler.
131
+ return realSingleton();
132
+ #endif
133
+ }
134
+
135
+ // ------------------------------------------------------------------------
136
+ //
137
+ // Accessing operators by schema
138
+ //
139
+ // ------------------------------------------------------------------------
140
+
141
+ /**
142
+ * Looks for an operator schema with the given name and overload name
143
+ * and returns it if it is registered WITH A SCHEMA.
144
+ * Returns nullopt otherwise.
145
+ */
146
+ std::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
147
+
148
+ /**
149
+ * Variant of findSchema that results in less code generated at the call site.
150
+ * It (1) takes const char* pointer rather than OperatorName (so we skip
151
+ * generating std::string constructor calls at the call site), and (2)
152
+ * it raises an exception if the operator is not found (so we skip
153
+ * generating exception raising code at the call site)
154
+ *
155
+ * Irritatingly, we still have to generate the handful of instructions
156
+ * for dealing with an exception being thrown during static initialization
157
+ * (e.g. __cxa_guard_abort). If we could annotate this method noexcept we
158
+ * could avoid this code too, but as the name of the function suggests,
159
+ * it does throw exceptions.
160
+ */
161
+ OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
162
+
163
+ // Like findSchema, but also returns OperatorHandle even if there is no schema
164
+ std::optional<OperatorHandle> findOp(const OperatorName& operator_name);
165
+
166
+ // Returns a list of all operator names present in the operatorLookupTable_
167
+ const std::vector<OperatorName> getAllOpNames();
168
+
169
+ // Returns a list of all operator names present in the operatorLookupTable_
170
+ // for a given dispatch key
171
+ const std::vector<OperatorName> getAllOpNamesForDispatchKey(DispatchKey k);
172
+
173
+ // ------------------------------------------------------------------------
174
+ //
175
+ // Invoking operators
176
+ //
177
+ // ------------------------------------------------------------------------
178
+
179
+ template <class Return, class... Args>
180
+ Return call(const TypedOperatorHandle<Return(Args...)>& op, Args... args)
181
+ const;
182
+
183
+ template <class Return, class... Args>
184
+ static Return callWithDispatchKeySlowPath(
185
+ const TypedOperatorHandle<Return(Args...)>& op,
186
+ at::StepCallbacks& stepCallbacks,
187
+ DispatchKeySet dispatchKeySet,
188
+ const KernelFunction& kernel,
189
+ Args... args);
190
+
191
+ // Like call, but intended for use in a redispatch in kernels that have
192
+ // explicitly performed the DispatchKey update calculatulation. This will take
193
+ // the DispatchKeySet completely as is and dispatch to the kernel of the
194
+ // corresponding highest priority key in the set. Note that this version of
195
+ // redispatch treats the inputted DispatchKeySet *as is*, and does NOT mask
196
+ // out the highest priority key. See Note [Plumbing Keys Through The
197
+ // Dispatcher]
198
+ template <class Return, class... Args>
199
+ Return redispatch(
200
+ const TypedOperatorHandle<Return(Args...)>& op,
201
+ DispatchKeySet currentDispatchKeySet,
202
+ Args... args) const;
203
+
204
+ // Invoke an operator via the boxed calling convention using an IValue stack
205
+ void callBoxed(const OperatorHandle& op, Stack* stack) const;
206
+ void callBoxedForDispatchKey(
207
+ const OperatorHandle& op,
208
+ DispatchKey dk,
209
+ Stack* stack) const;
210
+
211
+ // TODO: This will only be useful if we write a backend fallback that plumbs
212
+ // dispatch keys (currently there are none) See Note [Plumbing Keys Through
213
+ // The Dispatcher]
214
+ void redispatchBoxed(
215
+ const OperatorHandle& op,
216
+ DispatchKeySet dispatchKeySet,
217
+ Stack* stack) const;
218
+
219
+ bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
220
+ auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
221
+ if (dispatch_ix < 0)
222
+ return false;
223
+ return backendFallbackKernels_[dispatch_ix].kernel.isValid();
224
+ }
225
+
226
+ // Used by torchdeploy/multipy for multiple // codespell:ignore: multipy
227
+ // interpreters racing.
228
+ void waitForDef(const FunctionSchema& schema);
229
+ void waitForImpl(
230
+ const OperatorName& op_name,
231
+ std::optional<DispatchKey> dispatch_key);
232
+
233
+ // ------------------------------------------------------------------------
234
+ //
235
+ // Performing registrations (NON user public; use op_registration)
236
+ //
237
+ // ------------------------------------------------------------------------
238
+
239
+ /**
240
+ * Register a new operator schema.
241
+ *
242
+ * If a schema with the same operator name and overload name already exists,
243
+ * this function will check that both schemas are exactly identical.
244
+ */
245
+ RegistrationHandleRAII registerDef(
246
+ FunctionSchema schema,
247
+ std::string debug,
248
+ std::vector<at::Tag> tags = {});
249
+
250
+ /**
251
+ * Register a kernel to the dispatch table for an operator.
252
+ * If dispatch_key is nullopt, then this registers a fallback kernel.
253
+ *
254
+ * @return A RAII object that manages the lifetime of the registration.
255
+ * Once that object is destructed, the kernel will be deregistered.
256
+ */
257
+ // NB: steals the inferred function schema, as we may need to hold on to
258
+ // it for a bit until the real schema turns up
259
+ RegistrationHandleRAII registerImpl(
260
+ OperatorName op_name,
261
+ std::optional<DispatchKey> dispatch_key,
262
+ KernelFunction kernel,
263
+ std::optional<impl::CppSignature> cpp_signature,
264
+ std::unique_ptr<FunctionSchema> inferred_function_schema,
265
+ std::string debug);
266
+
267
+ /**
268
+ * Given an operator, tells the Dispatcher that we have implemented a fake
269
+ * impl for this op in the given Python module. Call this a "pystub".
270
+ */
271
+ RegistrationHandleRAII registerPythonModule(
272
+ const OperatorName& op_name,
273
+ const char* pymodule,
274
+ const char* context);
275
+
276
+ /**
277
+ * Given an operator, throws if we have a pystub.
278
+ */
279
+ void throwIfHasPythonModule(OperatorName op_name);
280
+
281
+ std::optional<std::pair<const char*, const char*>> getPyStub(
282
+ OperatorName op_name);
283
+
284
+ /**
285
+ * Register a new operator by name.
286
+ */
287
+ RegistrationHandleRAII registerName(OperatorName op_name);
288
+
289
+ /**
290
+ * Register a fallback kernel for a backend.
291
+ * If an operator is called but there is no concrete kernel for the dispatch
292
+ * key of the given operator arguments, it will check if there is such a
293
+ * fallback kernel for the given dispatch key and, if yes, call that one.
294
+ */
295
+ RegistrationHandleRAII registerFallback(
296
+ DispatchKey dispatch_key,
297
+ KernelFunction kernel,
298
+ std::string debug);
299
+
300
+ /**
301
+ * Use to register whenever we had a TORCH_LIBRARY declaration in the frontend
302
+ * API. These invocations are only permitted once per program, so we raise
303
+ * an error if this is called again for the same namespace.
304
+ */
305
+ RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
306
+
307
+ // ------------------------------------------------------------------------
308
+ //
309
+ // Listeners on registrations
310
+ //
311
+ // ------------------------------------------------------------------------
312
+
313
+ /**
314
+ * Add a listener that gets called whenever a new op is registered or an
315
+ * existing op is deregistered. Immediately after registering, this listener
316
+ * gets called for all previously registered ops, so it can be used to keep
317
+ * track of ops registered with this dispatcher.
318
+ */
319
+ RegistrationHandleRAII addRegistrationListener(
320
+ std::unique_ptr<OpRegistrationListener> listener);
321
+
322
+ void checkInvariants() const;
323
+
324
+ //
325
+ // ------------------------------------------------------------------------
326
+ //
327
+ // Assertions
328
+ //
329
+ // ------------------------------------------------------------------------
330
+
331
+ /**
332
+ * For testing purposes.
333
+ * Returns a list of all operators that were created through calls to
334
+ * registerImpl(), without any corresponding calls to registerDef(). After
335
+ * static initialization is done this is almost certainly a bug, as the
336
+ * created OperatorHandle won't have any schema associated with it and users
337
+ * calling the op through the dispatcher won't be able to access it
338
+ *
339
+ * Note that we cannot enforce this invariant "as we go" during static
340
+ * initialization, due to undefined static initialization order- we have no
341
+ * guarantees over the order in which .def() and .impl() calls are registered
342
+ * in the dispatcher at static initialization time. So this function should
343
+ * only be called after static initialization.
344
+ */
345
+ std::vector<OperatorHandle> findDanglingImpls() const;
346
+
347
+ /**
348
+ * Useful for inspecting global Dispatcher registration state.
349
+ * Returns the names of all operators with a kernel registered for the
350
+ * specified DispatchKey. If no DispatchKey is specified, it returns all
351
+ * registered operators.
352
+ */
353
+ std::vector<OperatorName> getRegistrationsForDispatchKey(
354
+ std::optional<DispatchKey> k) const;
355
+
356
+ private:
357
+ Dispatcher();
358
+
359
+ static int64_t sequenceNumberForRunningRecordFunction(
360
+ DispatchKey dispatchKey,
361
+ DispatchKeySet dispatchKeySet);
362
+ static void runRecordFunction(
363
+ at::RecordFunction& guard,
364
+ at::RecordFunction::schema_ref_t schema_ref,
365
+ DispatchKey dispatchKey,
366
+ DispatchKeySet dispatchKeySet);
367
+ static void runRecordFunction(
368
+ at::RecordFunction& guard,
369
+ at::RecordFunction::schema_ref_t schema_ref,
370
+ DispatchKey dispatchKey,
371
+ DispatchKeySet dispatchKeySet,
372
+ c10::ArrayRef<const c10::IValue> args);
373
+
374
+ #ifdef FBCODE_CAFFE2
375
+ static bool profilingOperatorEvents();
376
+ static void fireOpStartUSDT(
377
+ at::RecordFunction::schema_ref_t schema_ref,
378
+ std::vector<void*>& argsAddresses,
379
+ std::vector<const char*>& argsTypes);
380
+ static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref);
381
+ #endif // FBCODE_CAFFE2
382
+
383
+ OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
384
+ OperatorHandle findOrRegisterName_(const OperatorName& op_name);
385
+
386
+ void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
387
+ void deregisterImpl_(
388
+ const OperatorHandle& op,
389
+ const OperatorName& op_name,
390
+ std::optional<DispatchKey> dispatch_key,
391
+ impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
392
+ void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
393
+ void deregisterFallback_(DispatchKey dispatchKey);
394
+ void deregisterLibrary_(const std::string& ns);
395
+ void cleanup(const OperatorHandle& op, const OperatorName& op_name);
396
+ void checkSchemaCompatibility(
397
+ const OperatorHandle& op,
398
+ const FunctionSchema& schema,
399
+ const std::string& debug);
400
+
401
+ std::list<OperatorDef> operators_;
402
+ #if !defined(C10_MOBILE)
403
+ LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>>
404
+ operatorLookupTable_;
405
+ #else
406
+ RWSafeLeftRightWrapper<ska::flat_hash_map<OperatorName, OperatorHandle>>
407
+ operatorLookupTable_;
408
+ #endif
409
+ // Map from namespace to debug string (saying, e.g., where the library was
410
+ // defined)
411
+ ska::flat_hash_map<std::string, std::string> libraries_;
412
+
413
+ std::array<impl::AnnotatedKernel, num_runtime_entries>
414
+ backendFallbackKernels_;
415
+
416
+ std::unique_ptr<detail::RegistrationListenerList> listeners_;
417
+
418
+ // This condition variable gets notified whenever we add a new def/impl to the
419
+ // dispatch table. This is primarily used by multiply/torchdeploy, when
420
+ // we have multiple interpreters trying to register to the dispatch table.
421
+ // In this situation, whenever the non-primary interpreter would have tried
422
+ // to register to the dispatch table, instead it will check to see if the
423
+ // expected registration has already been made, and if it hasn't, wait on
424
+ // this condition variable to see if it was just racing with the primary
425
+ // interpreter.
426
+ //
427
+ // We expect it to be rare for there to be any waiters on this condition
428
+ // variable. This is mostly just to help give better diagnostics if
429
+ // something goes horribly wrong
430
+ std::condition_variable cond_var_;
431
+
432
+ // Protect concurrent access to the dispatcher. We store this in a
433
+ // `shared_ptr` as we return callbacks that call back into dispatcher methods,
434
+ // and we need to be able to handle and guard against the event when the
435
+ // `Dispatcher` has been destroyed before the callbacks fire.
436
+ std::shared_ptr<Guard> guard_;
437
+ };
438
+
439
+ /**
440
+ * This is a handle to an operator schema registered with the dispatcher.
441
+ * This handle can be used to register kernels with the dispatcher or
442
+ * to lookup a kernel for a certain set of arguments.
443
+ */
444
+ class TORCH_API OperatorHandle {
445
+ template <typename T>
446
+ friend struct std::hash;
447
+
448
+ public:
449
+ OperatorHandle(OperatorHandle&&) noexcept = default;
450
+ OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
451
+ OperatorHandle(const OperatorHandle&) = default;
452
+ OperatorHandle& operator=(const OperatorHandle&) = default;
453
+ // NOLINTNEXTLINE(performance-trivially-destructible)
454
+ ~OperatorHandle();
455
+
456
+ const OperatorName& operator_name() const {
457
+ return operatorDef_->op.operator_name();
458
+ }
459
+
460
+ bool hasSchema() const {
461
+ return operatorDef_->op.hasSchema();
462
+ }
463
+
464
+ const FunctionSchema& schema() const {
465
+ return operatorDef_->op.schema();
466
+ }
467
+
468
+ const std::string& debug() const {
469
+ return operatorDef_->op.debug();
470
+ }
471
+
472
+ std::string dumpState() const {
473
+ return operatorDef_->op.dumpState();
474
+ }
475
+
476
+ bool hasKernelForDispatchKey(DispatchKey k) const {
477
+ return operatorDef_->op.hasKernelForDispatchKey(k);
478
+ }
479
+
480
+ bool isKernelFallthroughKernel(DispatchKey k) const {
481
+ return operatorDef_->op.kernelForDispatchKey(k).isFallthrough();
482
+ }
483
+
484
+ bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
485
+ return operatorDef_->op.hasKernelForAnyDispatchKey(k);
486
+ }
487
+
488
+ bool hasComputedKernelForDispatchKey(DispatchKey k) const {
489
+ return operatorDef_->op.hasComputedKernelForDispatchKey(k);
490
+ }
491
+
492
+ SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
493
+ return operatorDef_->op.getComputedKernelForDispatchKey(k);
494
+ }
495
+
496
+ std::string dumpComputedTable() const {
497
+ return operatorDef_->op.dumpComputedTable();
498
+ }
499
+
500
+ void checkInvariants() const {
501
+ operatorDef_->op.checkInvariants();
502
+ }
503
+
504
+ c10::ArrayRef<at::Tag> getTags() const {
505
+ return operatorDef_->op.getTags();
506
+ }
507
+
508
+ void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
509
+ operatorDef_->op.setReportErrorCallback_(std::move(callback));
510
+ }
511
+
512
+ bool hasTag(const at::Tag& tag) const {
513
+ for (const auto& tag_ : getTags()) {
514
+ if (tag == tag_) {
515
+ return true;
516
+ }
517
+ }
518
+ return false;
519
+ }
520
+
521
+ template <class FuncType>
522
+ TypedOperatorHandle<FuncType> typed() const {
523
+ // NB: This assert is not 100% sound: you can retrieve a typed() operator
524
+ // handle prior to ANY C++ signature being registered on the operator
525
+ // and the check will say everything is OK (at which point you can then
526
+ // smuggle in a kernel that is typed incorrectly). For everything
527
+ // in core library this won't happen, because all the static registrations
528
+ // will be done by the time a typed() handle is acquired.
529
+ #if !defined C10_MOBILE
530
+ operatorDef_->op.assertSignatureIsCorrect<FuncType>();
531
+ if (fn_has_symint<FuncType>::value) {
532
+ operatorDef_->op.assertSignatureIsCorrect<
533
+ typename fn_remove_symint<FuncType>::type>();
534
+ }
535
+ #endif
536
+ return TypedOperatorHandle<FuncType>(operatorIterator_);
537
+ }
538
+
539
+ void callBoxed(Stack* stack) const {
540
+ c10::Dispatcher::singleton().callBoxed(*this, stack);
541
+ }
542
+
543
+ void callBoxed(Stack& stack) const {
544
+ callBoxed(&stack);
545
+ }
546
+
547
+ void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const {
548
+ c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack);
549
+ }
550
+
551
+ void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
552
+ c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
553
+ }
554
+
555
+ template <typename F>
556
+ PyObject* getPythonOp(
557
+ c10::impl::PyInterpreter* self_interpreter,
558
+ F slow_accessor) const {
559
+ return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
560
+ }
561
+
562
+ bool operator==(const OperatorHandle& other) const {
563
+ return operatorDef_ == other.operatorDef_;
564
+ }
565
+
566
+ bool operator!=(const OperatorHandle& other) const {
567
+ return operatorDef_ != other.operatorDef_;
568
+ }
569
+
570
+ private:
571
+ explicit OperatorHandle(
572
+ std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
573
+ : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
574
+ friend class Dispatcher;
575
+ template <class>
576
+ friend class TypedOperatorHandle;
577
+
578
+ // Storing a direct pointer to the OperatorDef even though we
579
+ // already have the iterator saves an instruction in the critical
580
+ // dispatch path. The iterator is effectively a
581
+ // pointer-to-std::list-node, and (at least in libstdc++'s
582
+ // implementation) the element is at an offset 16 bytes from that,
583
+ // because the prev/next pointers come first in the list node
584
+ // struct. So, an add instruction would be necessary to convert from the
585
+ // iterator to an OperatorDef*.
586
+ Dispatcher::OperatorDef* operatorDef_;
587
+
588
+ // We need to store this iterator in order to make
589
+ // Dispatcher::cleanup() fast -- it runs a lot on program
590
+ // termination (and presumably library unloading).
591
+ std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
592
+ };
593
+
594
+ /**
595
+ * This is a handle to an operator schema registered with the dispatcher.
596
+ * It holds the same information as an OperatorHandle, but it is templated
597
+ * on the operator arguments and allows calling the operator in an
598
+ * unboxed way.
599
+ */
600
+ template <class FuncType>
601
+ class TypedOperatorHandle final {
602
+ static_assert(
603
+ guts::false_t<FuncType>(),
604
+ "FuncType in OperatorHandle::typed<FuncType> was not a valid function type");
605
+ };
606
+ template <class Return, class... Args>
607
+ class TypedOperatorHandle<Return(Args...)> final : public OperatorHandle {
608
+ public:
609
+ TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
610
+ TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
611
+ TypedOperatorHandle(const TypedOperatorHandle&) = default;
612
+ TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;
613
+
614
+ // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use
615
+ // &&
616
+ C10_ALWAYS_INLINE Return call(Args... args) const {
617
+ return c10::Dispatcher::singleton().call<Return, Args...>(
618
+ *this, std::forward<Args>(args)...);
619
+ }
620
+
621
+ // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use
622
+ // &&
623
+ C10_ALWAYS_INLINE Return
624
+ redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
625
+ return c10::Dispatcher::singleton().redispatch<Return, Args...>(
626
+ *this, currentDispatchKeySet, std::forward<Args>(args)...);
627
+ }
628
+
629
+ private:
630
+ explicit TypedOperatorHandle(
631
+ std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
632
+ : OperatorHandle(operatorIterator) {}
633
+ friend class OperatorHandle;
634
+ };
635
+
636
+ namespace detail {
637
+ template <class... Args>
638
+ inline void unused_arg_(const Args&... /*unused*/) {}
639
+
640
+ // CaptureKernelCall is intended to capture return values from Dispatcher
641
+ // unboxed kernel calls. A record function may request to get outputs from the
642
+ // kernel calls. For boxed kernels, it's straightforward, the returned values
643
+ // are in the stack object. The stack can be passed to record functions. For
644
+ // unboxed kernels, we need to handle different kinds of return values, cache
645
+ // them temporarily, then release the values for the actual function call
646
+ // return.
647
+ template <typename ReturnType>
648
+ struct CaptureKernelCall {
649
+ template <typename F, typename... Args>
650
+ CaptureKernelCall(
651
+ const F& kernel,
652
+ const TypedOperatorHandle<ReturnType(Args...)>& op,
653
+ const DispatchKeySet& dispatchKeySet,
654
+ Args&&... args)
655
+ // Calls the kernel and capture the result in output_.
656
+ : output_{kernel.template call<ReturnType, Args...>(
657
+ op,
658
+ dispatchKeySet,
659
+ std::forward<Args>(args)...)} {}
660
+ // Wraps the return values in a Stack.
661
+ Stack getOutputs() {
662
+ Stack stack;
663
+ impl::push_outputs<ReturnType, false>::copy(output_, &stack);
664
+ return stack;
665
+ }
666
+ // Since we are returning the output_, we don't expect the output_ to be used
667
+ // afterward. Copy elision and RVO do not apply to class data members. Using
668
+ // move semantic to avoid copies when possible.
669
+ ReturnType release() && {
670
+ return std::move(output_);
671
+ }
672
+
673
+ private:
674
+ ReturnType output_;
675
+ };
676
+
677
+ // Handle the lvalue reference differently since it should not be moved.
678
+ template <>
679
+ inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
680
+ return output_;
681
+ }
682
+
683
+ // Handle case where the kernel returns void.
684
+ template <>
685
+ struct CaptureKernelCall<void> {
686
+ template <typename F, typename... Args>
687
+ CaptureKernelCall(
688
+ const F& kernel,
689
+ const TypedOperatorHandle<void(Args...)>& op,
690
+ const DispatchKeySet& dispatchKeySet,
691
+ Args&&... args) {
692
+ // Calling the kernel and no need to capture void.
693
+ kernel.template call<void, Args...>(
694
+ op, dispatchKeySet, std::forward<Args>(args)...);
695
+ }
696
+ Stack getOutputs() {
697
+ return Stack();
698
+ }
699
+ void release() && {}
700
+ };
701
+
702
+ TORCH_API void _print_dispatch_trace(
703
+ const std::string& label,
704
+ const std::string& op_name,
705
+ const DispatchKeySet& dispatchKeySet);
706
+
707
+ } // namespace detail
708
+
709
+ // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
710
+ template <class Return, class... Args>
711
+ inline Return Dispatcher::callWithDispatchKeySlowPath(
712
+ const TypedOperatorHandle<Return(Args...)>& op,
713
+ at::StepCallbacks& stepCallbacks,
714
+ DispatchKeySet dispatchKeySet,
715
+ const KernelFunction& kernel,
716
+ Args... args) {
717
+ // If callbacks need inputs, we box the arguments and pass them to the guard.
718
+ // Note: For perf reasons we wouldn't want to prematurely box the arguments.
719
+ at::RecordFunction guard(std::move(stepCallbacks));
720
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
721
+ auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
722
+ auto& schema = op.schema();
723
+ auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
724
+ constexpr auto num_boxed_args = impl::boxed_size<Args...>();
725
+ if constexpr (num_boxed_args != 0) {
726
+ if (guard.needsInputs()) {
727
+ // If we used std::array<IValue, num_boxed_args> here, we would
728
+ // have to spend time default constructing the IValues in
729
+ // boxedArgs. aligned_storage has no such requirement.
730
+ // NOLINTNEXTLINE(*array*)
731
+ alignas(IValue) std::byte boxedArgs[num_boxed_args * sizeof(IValue)];
732
+ // For debugging only; could be removed (but the compiler will do
733
+ // that for us and it's nice to have the extra assurance of
734
+ // correctness from our debug builds).
735
+ IValue* boxedArgsPtr = reinterpret_cast<IValue*>(boxedArgs);
736
+ impl::boxArgsToStack(boxedArgsPtr, args...);
737
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
738
+ reinterpret_cast<std::byte*>(boxedArgsPtr) ==
739
+ boxedArgs + num_boxed_args * sizeof(IValue));
740
+ // I don't *think* we need std::launder here, because IValue has
741
+ // no subclasses and no const or reference fields.
742
+ runRecordFunction(
743
+ guard,
744
+ schema_ref,
745
+ dispatchKey,
746
+ dispatchKeySet,
747
+ c10::ArrayRef<const c10::IValue>(
748
+ reinterpret_cast<IValue*>(boxedArgs), num_boxed_args));
749
+ boxedArgsPtr = reinterpret_cast<IValue*>(boxedArgs);
750
+ for (size_t ii = 0; ii < num_boxed_args; ++ii) {
751
+ (boxedArgsPtr + ii)->~IValue();
752
+ }
753
+ } else {
754
+ runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
755
+ }
756
+ } else {
757
+ runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
758
+ }
759
+
760
+ if (C10_UNLIKELY(guard.needsOutputs())) {
761
+ // Calls the kernel and capture the output temporarily to pass to
762
+ // RecordFunction.
763
+ detail::CaptureKernelCall<Return> captureKernelCall(
764
+ kernel, op, dispatchKeySet, std::forward<Args>(args)...);
765
+ guard.setOutputs(captureKernelCall.getOutputs());
766
+ // Releases the captured output to return to caller.
767
+ return std::move(captureKernelCall).release();
768
+ }
769
+
770
+ // keeping the guard alive while executing the kernel
771
+ return kernel.template call<Return, Args...>(
772
+ op, dispatchKeySet, std::forward<Args>(args)...);
773
+ }
774
+
775
+ // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
776
+ template <class Return, class... Args>
777
+ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(
778
+ const TypedOperatorHandle<Return(Args...)>& op,
779
+ Args... args) const {
780
+ auto dispatchKeySet =
781
+ op.operatorDef_->op.dispatchKeyExtractor()
782
+ .template getDispatchKeySetUnboxed<Args...>(args...);
783
+ #if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
784
+ DispatchTraceNestingGuard debug_guard;
785
+ if (show_dispatch_trace()) {
786
+ detail::_print_dispatch_trace(
787
+ "[call]", toString(op.operator_name()), dispatchKeySet);
788
+ }
789
+ #endif
790
+ const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
791
+ #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
792
+ auto step_callbacks =
793
+ at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
794
+ if (C10_UNLIKELY(
795
+ step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
796
+ return callWithDispatchKeySlowPath<Return, Args...>(
797
+ op,
798
+ *step_callbacks,
799
+ dispatchKeySet,
800
+ kernel,
801
+ std::forward<Args>(args)...);
802
+ }
803
+ #endif // PYTORCH_DISABLE_PER_OP_PROFILING
804
+
805
+ #ifdef FBCODE_CAFFE2
806
+ if (profilingOperatorEvents()) {
807
+ std::vector<void*> argsAddresses = {(void*)(&args)...};
808
+ std::vector<const char*> argsTypes = {(typeid(args).name())...};
809
+ struct FireOpRAII {
810
+ FireOpRAII(
811
+ at::RecordFunction::schema_ref_t schema_ref,
812
+ std::vector<void*>& argsAddresses,
813
+ std::vector<const char*>& argsTypes)
814
+ : schema_ref_(schema_ref) {
815
+ fireOpStartUSDT(schema_ref, argsAddresses, argsTypes);
816
+ }
817
+ ~FireOpRAII() {
818
+ fireOpEndUSDT(schema_ref_);
819
+ }
820
+ at::RecordFunction::schema_ref_t schema_ref_;
821
+ } event(op.schema(), argsAddresses, argsTypes);
822
+ return kernel.template call<Return, Args...>(
823
+ op, dispatchKeySet, std::forward<Args>(args)...);
824
+ } else {
825
+ return kernel.template call<Return, Args...>(
826
+ op, dispatchKeySet, std::forward<Args>(args)...);
827
+ }
828
+ #else
829
+ return kernel.template call<Return, Args...>(
830
+ op, dispatchKeySet, std::forward<Args>(args)...);
831
+ #endif // FBCODE_CAFFE2
832
+ }
833
+
834
+ // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
835
+ template <class Return, class... Args>
836
+ inline Return Dispatcher::redispatch(
837
+ const TypedOperatorHandle<Return(Args...)>& op,
838
+ DispatchKeySet currentDispatchKeySet,
839
+ Args... args) const {
840
+ // do not use RecordFunction on redispatch
841
+ #if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
842
+ DispatchTraceNestingGuard debug_guard;
843
+ if (show_dispatch_trace()) {
844
+ detail::_print_dispatch_trace(
845
+ "[redispatch]", toString(op.operator_name()), currentDispatchKeySet);
846
+ }
847
+ #endif
848
+ const KernelFunction& kernel =
849
+ op.operatorDef_->op.lookup(currentDispatchKeySet);
850
+ return kernel.template call<Return, Args...>(
851
+ op, currentDispatchKeySet, std::forward<Args>(args)...);
852
+ }
853
+
854
+ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack)
855
+ const {
856
+ // note: this doesn't need the mutex because write operations on the list keep
857
+ // iterators intact.
858
+ const auto& entry = op.operatorDef_->op;
859
+ auto dispatchKeySet =
860
+ entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
861
+ #if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
862
+ DispatchTraceNestingGuard debug_guard;
863
+ if (show_dispatch_trace()) {
864
+ detail::_print_dispatch_trace(
865
+ "[callBoxed]", toString(op.operator_name()), dispatchKeySet);
866
+ }
867
+ #endif
868
+ const auto& kernel = entry.lookup(dispatchKeySet);
869
+ #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
870
+ auto step_callbacks =
871
+ at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
872
+ if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
873
+ at::RecordFunction guard(std::move(*step_callbacks));
874
+ auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
875
+ auto& schema = op.schema();
876
+ auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
877
+ guard.needsInputs()
878
+ ? runRecordFunction(
879
+ guard,
880
+ schema_ref,
881
+ dispatchKey,
882
+ dispatchKeySet,
883
+ c10::ArrayRef<const c10::IValue>(stack->data(), stack->size()))
884
+ : runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
885
+
886
+ // keeping the guard alive while executing the kernel
887
+ kernel.callBoxed(op, dispatchKeySet, stack);
888
+
889
+ if (C10_UNLIKELY(guard.needsOutputs())) {
890
+ guard.setOutputs(*stack);
891
+ }
892
+ return;
893
+ }
894
+ #endif // PYTORCH_DISABLE_PER_OP_PROFILING
895
+ kernel.callBoxed(op, dispatchKeySet, stack);
896
+ }
897
+
898
+ // NB: this doesn't count as a "true" dispatcher jump, so no instrumentation
899
+ inline void Dispatcher::callBoxedForDispatchKey(
900
+ const OperatorHandle& op,
901
+ DispatchKey dk,
902
+ Stack* stack) const {
903
+ // note: this doesn't need the mutex because write operations on the list keep
904
+ // iterators intact.
905
+ const auto& entry = op.operatorDef_->op;
906
+ // We still compute this as we're obligated to pass it on to the internal
907
+ // kernel, if it is a boxed fallback
908
+ auto dispatchKeySet =
909
+ entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
910
+ const auto& kernel = ([&]() {
911
+ if (op.hasKernelForDispatchKey(dk)) {
912
+ return entry.kernelForDispatchKey(dk);
913
+ } else {
914
+ auto idx = getDispatchTableIndexForDispatchKey(dk);
915
+ TORCH_INTERNAL_ASSERT(idx >= 0);
916
+ return backendFallbackKernels_[idx].kernel;
917
+ }
918
+ })();
919
+ kernel.callBoxed(op, dispatchKeySet, stack);
920
+ }
921
+
922
+ inline void Dispatcher::redispatchBoxed(
923
+ const OperatorHandle& op,
924
+ DispatchKeySet dispatchKeySet,
925
+ Stack* stack) const {
926
+ // note: this doesn't need the mutex because write operations on the list keep
927
+ // iterators intact.
928
+ const auto& entry = op.operatorDef_->op;
929
+ #if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG)
930
+ DispatchTraceNestingGuard debug_guard;
931
+ if (show_dispatch_trace()) {
932
+ detail::_print_dispatch_trace(
933
+ "[redispatchBoxed]", toString(op.operator_name()), dispatchKeySet);
934
+ }
935
+ #endif
936
+ const auto& kernel = entry.lookup(dispatchKeySet);
937
+ kernel.callBoxed(op, dispatchKeySet, stack);
938
+ }
939
+
940
+ } // namespace c10
941
+
942
+ namespace std {
943
+
944
+ template <>
945
+ struct hash<c10::OperatorHandle> {
946
+ size_t operator()(const c10::OperatorHandle& op) const noexcept {
947
+ return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
948
+ }
949
+ };
950
+
951
+ } // namespace std
952
+
953
+ #else
954
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
955
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/operator_name.h>
5
+ #include <string>
6
+ #include <unordered_set>
7
+
8
+ namespace c10 {
9
+
10
+ struct TORCH_API ObservedOperators {
11
+ ObservedOperators() = delete;
12
+
13
+ static bool isObserved(const OperatorName& name);
14
+
15
+ static std::unordered_set<std::string>& getUnobservedOperatorList();
16
+ };
17
+
18
+ } // namespace c10
19
+
20
+ #else
21
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
22
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/boxing/KernelFunction.h>
5
+ #include <ATen/core/dispatch/DispatchKeyExtractor.h>
6
+ #include <ATen/core/function_schema.h>
7
+ #include <ATen/core/ivalue.h>
8
+ #include <c10/core/DispatchKey.h>
9
+ #include <c10/core/PyHandleCache.h>
10
+ #include <c10/core/SafePyObject.h>
11
+ #include <c10/util/Metaprogramming.h>
12
+ #include <c10/util/flat_hash_map.h>
13
+
14
+ #include <ATen/core/dispatch/CppSignature.h>
15
+ #include <ATen/core/dispatch/OperatorOptions.h>
16
+ #include <ATen/core/dispatch/RegistrationHandleRAII.h>
17
+ #include <ATen/core/enum_tag.h>
18
+
19
+ #include <array>
20
+ #include <list>
21
+ #include <optional>
22
+
23
+ #ifdef C10_MOBILE
24
+ #define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
25
+ #endif
26
+
27
+ namespace c10 {
28
+
29
+ class Dispatcher;
30
+
31
+ namespace impl {
32
+
33
+ // This data structure represents a kernel that was registered to us from a
34
+ // user. Unlike KernelFunction, AnnotatedKernel contains some extra metadata
35
+ // about the kernel that isn't necessary for actual dispatching (this is why
36
+ // we don't put AnnotatedKernel in the actual DispatchTable), but is useful for
37
+ // giving good error messages.
38
+ struct AnnotatedKernel final {
39
+ AnnotatedKernel(
40
+ KernelFunction k,
41
+ std::unique_ptr<FunctionSchema> s,
42
+ std::string d)
43
+ : kernel(std::move(k)),
44
+ inferred_function_schema(std::move(s)),
45
+ debug(std::move(d)) {}
46
+ AnnotatedKernel() = default;
47
+ KernelFunction kernel;
48
+ std::unique_ptr<FunctionSchema> inferred_function_schema;
49
+ // A little debug string to help us identify the kernel in question.
50
+ // Most importantly it records the TORCH_LIBRARY block that did the
51
+ // registration.
52
+ std::string debug;
53
+ };
54
+
55
+ // This data structure represents operator schema, with metadata specifying
56
+ // where the registration of this schema occurred
57
+ struct AnnotatedSchema final {
58
+ AnnotatedSchema(FunctionSchema s, std::string d)
59
+ : schema(std::move(s)), debug(std::move(d)) {}
60
+ FunctionSchema schema;
61
+ std::string debug;
62
+ };
63
+
64
+ // Internal data structure that records information about a specific operator.
65
+ // It's not part of the public API; typically, users will interact with
66
+ // OperatorHandle instead.
67
+ //
68
+ // Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher
69
+ // lock (this is important because some methods in OperatorEntry access
70
+ // dispatcher state)
71
+ class TORCH_API OperatorEntry final {
72
+ public:
73
+ explicit OperatorEntry(OperatorName&& operator_name);
74
+
75
+ OperatorEntry(const OperatorEntry&) = delete;
76
+ OperatorEntry(OperatorEntry&&) noexcept = delete;
77
+ OperatorEntry& operator=(const OperatorEntry&) = delete;
78
+ OperatorEntry& operator=(OperatorEntry&&) noexcept = delete;
79
+
80
+ const FunctionSchema& schema() const {
81
+ TORCH_INTERNAL_ASSERT(
82
+ schema_.has_value(),
83
+ "Tried to access the schema for ",
84
+ name_,
85
+ " which doesn't have a schema registered yet");
86
+ return schema_->schema;
87
+ }
88
+ const std::string& debug() const {
89
+ TORCH_INTERNAL_ASSERT(schema_.has_value());
90
+ return schema_->debug;
91
+ }
92
+ bool hasSchema() const {
93
+ return schema_.has_value();
94
+ }
95
+
96
+ bool isObserved() const {
97
+ return is_observed_;
98
+ }
99
+
100
+ // We may allocate an OperatorEntry for an operator even when we don't
101
+ // have a schema. When we receive the schema registration, we post
102
+ // facto register a schema.
103
+ //
104
+ // NB: registerSchema/deregisterSchema are not idempotent; if you
105
+ // attempt to register a schema when one is already present or vice
106
+ // versa that is an error. (Refcounting for the registrations is
107
+ // handled in the OperatorHandle in Dispatcher)
108
+ void registerSchema(
109
+ FunctionSchema&& /*schema*/,
110
+ std::string&& debug,
111
+ std::vector<at::Tag> tags = {});
112
+ void deregisterSchema();
113
+
114
+ const OperatorName& operator_name() const {
115
+ return name_;
116
+ }
117
+
118
+ #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
119
+ using AnnotatedKernelContainer = std::array<AnnotatedKernel, 1>;
120
+ #else
121
+ using AnnotatedKernelContainer = std::list<AnnotatedKernel>;
122
+ #endif
123
+ using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator;
124
+
125
+ // Why are kernels and fallback asymmetric? It has to do with ownership.
126
+ // Kernels and the computed dispatch tables for them are canonically
127
+ // owned by OperatorEntry, but backend fallbacks are specified once
128
+ // and apply for all operators, so they should be owned by Dispatcher.
129
+ // However, the registration of a backend fallback affects the
130
+ // state of the computed dispatch table, so when a backend fallback
131
+ // is updated, we need to update the operator tables too. Thus,
132
+ // registerKernel is the mechanism by which we give kernels to
133
+ // operator entry to own (and update dispatch table), but we only
134
+ // need a non-owning mechanism to update fallback.
135
+
136
+ // Precondition: Dispatcher::mutex_ is held
137
+ // Postcondition: caller is responsible for disposing of the kernel
138
+ AnnotatedKernelContainerIterator registerKernel(
139
+ const Dispatcher& dispatcher,
140
+ std::optional<DispatchKey> dispatch_key,
141
+ KernelFunction kernel,
142
+ std::optional<CppSignature> cpp_signature,
143
+ std::unique_ptr<FunctionSchema> inferred_function_schema,
144
+ std::string debug);
145
+
146
+ // Precondition: Dispatcher::mutex_ is held
147
+ void deregisterKernel_(
148
+ const Dispatcher& dispatcher,
149
+ std::optional<DispatchKey> dispatch_key,
150
+ AnnotatedKernelContainerIterator kernel);
151
+
152
+ // Precondition: Dispatcher::mutex_ is held
153
+ void updateFallback(const Dispatcher& dispatcher, DispatchKey dispatch_key);
154
+
155
+ // Precondition: Dispatcher::mutex_ is held
156
+ void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
157
+ TORCH_INTERNAL_ASSERT(schema_.has_value());
158
+ schema_->schema.setAliasAnalysis(a);
159
+ }
160
+
161
+ std::string dumpComputedTable() const;
162
+ std::string dumpState() const;
163
+ void checkInvariants() const;
164
+
165
+ const DispatchKeyExtractor& dispatchKeyExtractor() const {
166
+ return dispatchKeyExtractor_;
167
+ }
168
+
169
+ // Asserts that the given FuncType is correct for calling this operator in an
170
+ // unboxed way.
171
+ template <class FuncType>
172
+ inline void assertSignatureIsCorrect() {
173
+ assertSignatureIsCorrect(
174
+ CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
175
+ }
176
+
177
+ void assertSignatureIsCorrect(
178
+ const CppSignature& call_signature,
179
+ bool has_symint) const;
180
+
181
+ [[noreturn]] void reportError(DispatchKey dispatchKey) const;
182
+
183
+ const KernelFunction& lookup(DispatchKeySet ks) const {
184
+ const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
185
+ if (C10_UNLIKELY(idx == -1)) {
186
+ reportError(ks.highestPriorityTypeId());
187
+ }
188
+ const auto& kernel = dispatchTable_[idx];
189
+ // A valid kernel *always* has a boxed kernel and *may* have an
190
+ // unboxed kernel. However, we typically do unboxed calls in at::
191
+ // APIs, where the kernel 1) will very likely be valid and 2)
192
+ // should have an unboxed kernel. Checking the unboxed kernel
193
+ // first will allow us to avoid touching the boxed kernel at all
194
+ // in the common case.
195
+ if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
196
+ if (!kernel.isValid()) {
197
+ reportError(ks.highestPriorityTypeId());
198
+ }
199
+ }
200
+ return kernel;
201
+ }
202
+
203
+ std::string listAllDispatchKeys() const;
204
+
205
+ // Returns true if kernel_ has entry for any key in ks.
206
+ //
207
+ // Invariant: There are no alias keys in the passed-in dispatch key set.
208
+ // Note [No Alias Keys in DispatchKeySet]
209
+ // Alias keys should be checked using `hasKernelForDispatchKey`
210
+ // Alias keys shouldn't go inside of a DispatchKeySet, since they can
211
+ // technically have a value > 63 (causing overflow).
212
+ bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
213
+ // Returns true if kernel_ has entry for a particular key.
214
+ bool hasKernelForDispatchKey(DispatchKey k) const;
215
+ // Retrieves the kernel entry at a particular key. Symmetric with
216
+ // hasKernelForDispatchKey. To get the AnnotatedKernel, see
217
+ // getKernelForDispatchKey (private)
218
+ const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
219
+ // Returns true if the "computed table" has an entry for a particular key.
220
+ bool hasComputedKernelForDispatchKey(DispatchKey k) const;
221
+ // Returns a KernelFunction corresponding to the kernel in dispatchTable
222
+ SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
223
+ // Returns all the operator tags added at the time of registration
224
+ const std::vector<at::Tag>& getTags() const;
225
+ void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
226
+
227
+ template <typename F>
228
+ PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor)
229
+ const {
230
+ return py_cache_.ptr_or(self_interpreter, slow_accessor);
231
+ }
232
+
233
+ private:
234
+ OperatorName name_;
235
+ std::optional<AnnotatedSchema> schema_;
236
+ #ifndef C10_MOBILE
237
+ std::vector<at::Tag> tags_;
238
+ #endif
239
+ std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
240
+ DispatchKeyExtractor dispatchKeyExtractor_;
241
+ // Pointer to the torch.ops.ns.op.overload object for speed
242
+ c10::PyHandleCache py_cache_;
243
+
244
+ // kernels_ stores all registered kernels for the corresponding dispatch key
245
+ // and catchAllKernels_ stores the catch-all kernels.
246
+ // If an operator library gets loaded that overwrites an already existing
247
+ // kernel, both kernels will be in that list but only the newer one will be in
248
+ // dispatchTable. If any of the kernels go away (say the library gets
249
+ // unloaded), we remove the kernel from this list and update the
250
+ // dispatchTable if necessary.
251
+ // Kernels in the list are ordered by registration time descendingly,
252
+ // newer registrations are before older registrations.
253
+ // We do not combine dispatchTable and kernels into one hash map because
254
+ // kernels is a larger data structure and accessed quite infrequently
255
+ // while dispatchTable is accessed often and should be kept small to fit
256
+ // into CPU caches.
257
+ // Invariants:
258
+ // - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
259
+ // - dispatchTable[dispatch_key] does not exist if and only if
260
+ // kernels_[dispatch_key] does not exist
261
+ // - If kernels_[dispatch_key] exists, then it has elements.
262
+ // It is never an empty list.
263
+ //
264
+ // Why do we do that?
265
+ // -----
266
+ // We mostly do this to enable Jupyter notebooks where a cell registering
267
+ // a kernel could be executed multiple times and the later execution
268
+ // should overwrite the earlier one. Note that this still fails when the
269
+ // function schema changed between the executions, but it works as long
270
+ // as the function schema didn't change. A better solution would be to
271
+ // unload the old extension library from the Jupyter cell when the cell is
272
+ // re-executed and then only allow one kernel here, i.e. error if a kernel
273
+ // is already registered, but that's a lot of effort to implement and
274
+ // currently not high-pri.
275
+ ska::flat_hash_map<
276
+ DispatchKey,
277
+ #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
278
+ // On mobile, we needn't worry about Jupyter notebooks.
279
+ std::array<AnnotatedKernel, 1>
280
+ #else
281
+ std::list<AnnotatedKernel>
282
+ #endif
283
+ >
284
+ kernels_;
285
+
286
+ const AnnotatedKernel& missingKernel() const;
287
+ const AnnotatedKernel& ambiguousAutogradOtherKernel() const;
288
+
289
+ // cpp_signature_ stores function signature if any of
290
+ // the kernels was created in a way that allowed us to know the function
291
+ // signature (i.e. by supplying an unboxed C++ kernel function).
292
+ // If this is set, it will be used to check that future kernel
293
+ // registrations match and it will be used in unboxed function calls
294
+ // to verify their arguments against the known function signature.
295
+ struct CppSignatureWithDebug {
296
+ CppSignature signature;
297
+ std::string debug;
298
+ std::optional<DispatchKey> dispatch_key;
299
+ };
300
+ std::optional<CppSignatureWithDebug> cpp_signature_;
301
+ std::optional<CppSignatureWithDebug> sym_cpp_signature_;
302
+
303
+ // A Python custom error handler for OperatorEntry::reportError
304
+ std::unique_ptr<c10::SafePyObject> report_error_callback_;
305
+
306
+ // Whether this operator needs to be observed with RecordFunction
307
+ const bool is_observed_;
308
+
309
+ [[noreturn]] void reportSignatureError(
310
+ const CppSignature& call_signature,
311
+ const CppSignatureWithDebug& saved_signature) const;
312
+ const KernelFunction& computeDispatchTableEntry(
313
+ const c10::Dispatcher& dispatcher,
314
+ DispatchKey dispatch_key) const;
315
+ std::pair<const AnnotatedKernel&, const char*>
316
+ computeDispatchTableEntryWithDebug(
317
+ const c10::Dispatcher& dispatcher,
318
+ DispatchKey dispatch_key) const;
319
+ // This function re-establishes the invariant that dispatchTable
320
+ // contains the front element from the kernels list for a given runtime
321
+ // dispatch key.
322
+ void updateDispatchTableEntry_(
323
+ const c10::Dispatcher& dispatcher,
324
+ DispatchKey dispatch_key);
325
+ // Like above, but also handles alias dispatch keys.
326
+ void updateDispatchTable_(
327
+ const c10::Dispatcher& dispatcher,
328
+ DispatchKey dispatch_key);
329
+ // Like above, but for ALL entries in the dispatch table.
330
+ void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
331
+ // Retrieves a pointer to AnnotatedKernel at
332
+ // kernels_.at(dispatch_key).front().
333
+ const AnnotatedKernel* getKernelForDispatchKey(
334
+ DispatchKey dispatch_key) const;
335
+ };
336
+
337
+ } // namespace impl
338
+ } // namespace c10
339
+
340
+ #else
341
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
342
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <cstdint>
5
+
6
+ namespace c10 {
7
+
8
+ enum class AliasAnalysisKind : uint8_t {
9
+ INTERNAL_SPECIAL_CASE,
10
+ CONSERVATIVE, // The most conservative alias analysis type, assumes
11
+ // side-effects. This is the default analysis.
12
+ FROM_SCHEMA,
13
+ PURE_FUNCTION
14
+ };
15
+
16
+ #if !defined(_MSC_VER)
17
+ constexpr // Our current MSVC version has a bug that doesn't allow this to be
18
+ // constexpr.
19
+ #endif
20
+ inline const char*
21
+ toString(AliasAnalysisKind aliasAnalysisKind) {
22
+ return (aliasAnalysisKind == AliasAnalysisKind::CONSERVATIVE) ? "CONSERVATIVE"
23
+ : (aliasAnalysisKind == AliasAnalysisKind::FROM_SCHEMA) ? "FROM_SCHEMA"
24
+ : (aliasAnalysisKind == AliasAnalysisKind::PURE_FUNCTION)
25
+ ? "PURE_FUNCTION"
26
+ : (aliasAnalysisKind == AliasAnalysisKind::INTERNAL_SPECIAL_CASE)
27
+ ? "INTERNAL_SPECIAL_CASE"
28
+ : "UNKNOWN";
29
+ }
30
+
31
+ } // namespace c10
32
+
33
+ #else
34
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
35
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <functional>
5
+
6
+ namespace c10 {
7
+
8
+ class RegistrationHandleRAII final {
9
+ public:
10
+ explicit RegistrationHandleRAII(std::function<void()> onDestruction)
11
+ : onDestruction_(std::move(onDestruction)) {}
12
+
13
+ ~RegistrationHandleRAII() {
14
+ if (onDestruction_) {
15
+ onDestruction_();
16
+ }
17
+ }
18
+
19
+ RegistrationHandleRAII(const RegistrationHandleRAII&) = delete;
20
+ RegistrationHandleRAII& operator=(const RegistrationHandleRAII&) = delete;
21
+
22
+ RegistrationHandleRAII(RegistrationHandleRAII&& rhs) noexcept
23
+ : onDestruction_(std::move(rhs.onDestruction_)) {
24
+ rhs.onDestruction_ = nullptr;
25
+ }
26
+
27
+ RegistrationHandleRAII& operator=(RegistrationHandleRAII&& rhs) noexcept {
28
+ onDestruction_ = std::move(rhs.onDestruction_);
29
+ rhs.onDestruction_ = nullptr;
30
+ return *this;
31
+ }
32
+
33
+ private:
34
+ std::function<void()> onDestruction_;
35
+ };
36
+
37
+ } // namespace c10
38
+
39
+ #else
40
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
41
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/adaption.h ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/Tensor.h>
5
+ #include <ATen/TensorUtils.h>
6
+ #include <ATen/core/List.h>
7
+ #include <c10/core/TensorOptions.h>
8
+
9
+ /*
10
+ * [Note: hacky wrapper removal for optional tensor]
11
+ *
12
+ * The kernel implementation takes an optional tensor marked in the schema as
13
+ * Tensor? but the C++ function takes Tensor instead of the std::optional<Tensor>
14
+ * expected by the dispatcher.
15
+ *
16
+ * To remove the hacky wrapper, the C++ function is changed to take
17
+ * std::optional<Tensor> and unwrap the Tensor value at the beginning of
18
+ * the function, e.g.:
19
+ * > c10::MaybeOwned<Tensor> weight_maybe_owned =
20
+ * > at::borrow_from_optional_tensor(weight_opt);
21
+ * > const Tensor& weight = *weight_maybe_owned;
22
+ *
23
+ * We may want to make the kernel handle optional directly without
24
+ * going through the creation of a default-constructed Tensor in
25
+ * at::borrow_from_optional_tensor.
26
+ */
27
+
28
+ /*
29
+ * [Note: hacky wrapper removal for TensorOptions]
30
+ *
31
+ * The kernel implementation takes a TensorOptions argument but the dispatcher
32
+ * expects separate arguments for dtype, layout, device, pin_memory.
33
+ *
34
+ * To remove the hacky wrapper, the kernel implementation is changed to take
35
+ * the 4 arguments (dtype, layout, device, pin_memory), and assemble the
36
+ * TensorOptions value at the beginning of the function, e.g.:
37
+ * > TensorOptions options = TensorOptions().dtype(dtype).layout(layout)
38
+ * > .device(device).pinned_memory(pin_memory);
39
+ *
40
+ * We may want make the kernel handle these parameters directly without going
41
+ * through the creation of a TensorOptions value.
42
+ */
43
+
44
+ namespace c10::impl {
45
+
46
+ TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);
47
+
48
+ inline void check_and_update_common_device(std::optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
49
+ // TODO: Remove this once the following issue is addressed:
50
+ // https://github.com/pytorch/pytorch/issues/57380
51
+ if (!tensor.defined()) {
52
+ return;
53
+ }
54
+
55
+ if (!common_device.has_value()) {
56
+ common_device = tensor.device();
57
+ return;
58
+ }
59
+
60
+ if (C10_UNLIKELY(common_device != tensor.device())) {
61
+ common_device_check_failure(*common_device, tensor, methodName, argName);
62
+ }
63
+ }
64
+
65
+ inline void check_and_update_common_device(std::optional<Device>& common_device, const std::optional<at::Tensor>& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
66
+ if (tensor.has_value()) {
67
+ check_and_update_common_device(common_device, tensor.value(), methodName, argName);
68
+ }
69
+ }
70
+
71
+ inline void check_and_update_common_device(std::optional<Device>& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
72
+ for (const auto& tensor : tensors) {
73
+ check_and_update_common_device(common_device, tensor, methodName, argName);
74
+ }
75
+ }
76
+
77
+ inline void check_and_update_common_device(std::optional<Device>& common_device, const List<std::optional<at::Tensor>>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
78
+ for (const auto& tensor : tensors) {
79
+ check_and_update_common_device(common_device, tensor, methodName, argName);
80
+ }
81
+ }
82
+ } // namespace c10::impl
83
+
84
+ #else
85
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
86
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/infer_schema.h ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ /**
5
+ * This file contains functionality to take a C++ function and infer its
6
+ * c10::FunctionSchema.
7
+ */
8
+
9
+ #include <ATen/core/function_schema.h>
10
+ #include <c10/util/Metaprogramming.h>
11
+
12
+ namespace c10 {
13
+ namespace detail::infer_schema {
14
+
15
+ /// The templated inference code creates `ArgumentDef` instead of `Argument`,
16
+ /// because that can be constructed at compile time and has a much smaller
17
+ /// binary size than having calls to `Argument` constructors in the template.
18
+ /// Creating `Argument` objects from `ArgumentDef` can then be done at
19
+ /// runtime in a non-templated way.
20
+ struct ArgumentDef final {
21
+ using GetTypeFn = TypePtr();
22
+ GetTypeFn* getTypeFn;
23
+ GetTypeFn* getFakeTypeFn;
24
+ constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {}
25
+ explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {}
26
+ };
27
+
28
+ template<bool V>
29
+ struct bool_t {};
30
+ template<> struct bool_t<true> : std::true_type {};
31
+ template<> struct bool_t<false> : std::false_type {};
32
+
33
+ /// Checks the static C++ types `Types` for correctness to catch common error cases.
34
+ template <class... Types>
35
+ constexpr int checkStaticTypes() {
36
+ // Give nice error messages for some of the common error cases.
37
+ // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
38
+ static_assert(std::conjunction_v<
39
+ bool_t<!std::is_integral_v<Types> || std::is_same_v<Types, int8_t> || std::is_same_v<Types, int64_t> || std::is_same_v<Types, bool>>...
40
+ >, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
41
+ static_assert(std::conjunction_v<
42
+ bool_t<!std::is_same_v<Types, float>>...
43
+ >, "INVALID TYPE: float is not supported as an argument type, use double instead");
44
+ return 0;
45
+ }
46
+
47
+ template <typename... Ts, size_t... Is>
48
+ constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...> /*unused*/) {
49
+ return (
50
+ // Check types for common errors
51
+ checkStaticTypes<Ts...>(),
52
+
53
+ // Create the return value
54
+ std::array<ArgumentDef, sizeof...(Ts)>{
55
+ ArgumentDef(&getTypePtrCopy<std::decay_t<Ts>>, &getFakeTypePtrCopy<std::decay_t<Ts>>)...}
56
+ );
57
+ }
58
+
59
+ /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
60
+ /// as template arguments.
61
+ template<class ParameterTypes> struct createArguments final {};
62
+ template<class... ParameterTypes>
63
+ struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
64
+ static constexpr std::array<ArgumentDef, sizeof...(ParameterTypes)> call() {
65
+ return createArgumentVectorFromTypes<ParameterTypes...>(
66
+ std::make_index_sequence<sizeof...(ParameterTypes)>()
67
+ );
68
+ }
69
+ };
70
+
71
+ /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
72
+ /// as a tuple (i.e. in the way c10 kernels return values).
73
+ /// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
74
+ /// It can be an empty tuple<>, or void for kernels that don't return anything.
75
+ /// It can be a single type A (i.e. no tuple) for the case where a kernel just
76
+ /// returns one value.
77
+ template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
78
+
79
+ template<class... ReturnTypes>
80
+ struct createReturns<std::tuple<ReturnTypes...>, void> final {
81
+ static constexpr std::array<ArgumentDef, sizeof...(ReturnTypes)> call() {
82
+ return createArgumentVectorFromTypes<ReturnTypes...>(
83
+ std::make_index_sequence<sizeof...(ReturnTypes)>()
84
+ );
85
+ }
86
+ };
87
+
88
+ template<class ReturnType>
89
+ struct createReturns<ReturnType, std::enable_if_t<!std::is_same_v<void, ReturnType> && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
90
+ static constexpr std::array<ArgumentDef, 1> call() {
91
+ return createReturns<std::tuple<ReturnType>>::call();
92
+ }
93
+ };
94
+
95
+ template<>
96
+ struct createReturns<void, void> final {
97
+ static constexpr std::array<ArgumentDef, 0> call() {
98
+ return createReturns<std::tuple<>>::call();
99
+ }
100
+ };
101
+
102
+ template <typename ReturnType>
103
+ struct createSingleReturn {
104
+ static constexpr std::array<ArgumentDef, 1> call() {
105
+ return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>());
106
+ }
107
+ };
108
+
109
+ TORCH_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
110
+ TORCH_API FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
111
+
112
+ /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
113
+ /// function. Flattens std::tuple returns into multiple return types
114
+ template <typename FunctionTraits>
115
+ FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() {
116
+ using ReturnType = typename FunctionTraits::return_type;
117
+ using ParameterTypes = typename FunctionTraits::parameter_types;
118
+
119
+ // arguments and returns are computed into a std::array at compile time and embedded into the binary.
120
+ // The only code executed at runtime here is the one that creates a std::vector
121
+ // of the arguments/returns from the std::array.
122
+ constexpr auto arguments = createArguments<ParameterTypes>::call();
123
+ constexpr auto returns = createReturns<ReturnType>::call();
124
+
125
+ return make_function_schema(arguments, returns);
126
+ }
127
+
128
+ /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
129
+ /// function. Preserves std::tuple returns as a Tuple return type
130
+ template <typename FunctionTraits>
131
+ FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) {
132
+ using ReturnType = typename FunctionTraits::return_type;
133
+ using ParameterTypes = typename FunctionTraits::parameter_types;
134
+
135
+ // arguments and returns are computed into a std::array at compile time and embedded into the binary.
136
+ // The only code executed at runtime here is the one that creates a std::vector
137
+ // of the arguments/returns from the std::array.
138
+ constexpr auto arguments = createArguments<ParameterTypes>::call();
139
+ constexpr auto returns = createSingleReturn<ReturnType>::call();
140
+
141
+ return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
142
+ }
143
+
144
+ }
145
+
146
+ template<class FuncType>
147
+ FunctionSchema inferFunctionSchemaFlattenedReturns() {
148
+ return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>();
149
+ }
150
+
151
+ template<class FuncType>
152
+ FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) {
153
+ return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
154
+ }
155
+
156
+ TORCH_API std::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);
157
+
158
+ }
159
+
160
+ #else
161
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
162
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // TODO: unify to C10_MOBILE. In theory this header could be used in OSS.
5
+ #ifdef TEMPLATE_SELECTIVE_BUILD
6
+ #include <ATen/selected_mobile_ops.h>
7
+ #endif
8
+
9
+ /**
10
+ * This header implements functionality to build PyTorch with only a certain
11
+ * set of operators (+ dependencies) included.
12
+ *
13
+ * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these
14
+ * two ops will be included in your build. The allowlist records operators
15
+ * only, no overloads; if you include aten::add, all overloads of aten::add
16
+ * will be included.
17
+ *
18
+ * Internally, this is done by removing the operator registration calls
19
+ * using compile time programming, and the linker will then prune all
20
+ * operator functions that weren't registered.
21
+ * See Note [Selective build] for more details
22
+ *
23
+ * WARNING: The allowlist mechanism doesn't work for all ways you could go about
24
+ * registering an operator. If the dispatch key / operator name is not
25
+ * sufficiently obvious at compile time, then the allowlisting mechanism
26
+ * will fail (and the operator will be included in the binary anyway).
27
+ */
28
+
29
+ #include <string_view>
30
+ #include <c10/core/DispatchKey.h>
31
+ #include <c10/macros/Macros.h>
32
+
33
+
34
+ #if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
35
+ #include <ATen/record_function.h>
36
+ #endif
37
+
38
+ namespace c10::impl {
39
+
40
+ constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item); // Forward Declare
41
+
42
+ /**
43
+ * In selective build mode returns true/false depending on whether a build
44
+ * feature is available or not.
45
+ *
46
+ * In instrumenting mode (tracing mode), always returns true, and doesn't
47
+ * trigger any side effects.
48
+ */
49
+ constexpr bool is_build_feature_available(const char* name) {
50
+ #if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
51
+ // Selective Build mode.
52
+ #if !defined(TORCH_BUILD_FEATURE_ALLOWLIST)
53
+ (void)name;
54
+ return true;
55
+ #else
56
+ return allowlist_contains(
57
+ C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST),
58
+ name);
59
+ #endif
60
+
61
+ #else
62
+ // Instrumenting mode.
63
+ (void)name;
64
+ return true;
65
+ #endif
66
+ }
67
+
68
+ [[noreturn]] void build_feature_required_feature_not_available(const char* feature);
69
+
70
+ /**
71
+ * Use BUILD_FEATURE_REQUIRED macro in user-code.
72
+ *
73
+ * In selective build mode becomes a no-op if the build feature passed
74
+ * in is available. If not available, throws an exception (c10::Error).
75
+ * The compiler is able to perform dead code elimination for code
76
+ * following this method if the build feature is not available.
77
+ *
78
+ * In instrumenting mode (tracing mode), registers (as a side effect)
79
+ * the presence of this specific build feature being triggered.
80
+ */
81
+ #if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) // selective build mode
82
+
83
+ #if defined(TORCH_BUILD_FEATURE_ALLOWLIST)
84
+ #define BUILD_FEATURE_REQUIRED(NAME) \
85
+ if (!c10::impl::is_build_feature_available(NAME)) { \
86
+ ::c10::impl::build_feature_required_feature_not_available(NAME); \
87
+ }
88
+ #else // Everything trivially selected
89
+ #define BUILD_FEATURE_REQUIRED(NAME)
90
+
91
+ #endif
92
+
93
+ #else // trace mode
94
+ #define BUILD_FEATURE_REQUIRED(NAME) \
95
+ RECORD_FUNCTION_WITH_SCOPE( \
96
+ at::RecordScope::BUILD_FEATURE, \
97
+ std::string(NAME), \
98
+ {});
99
+ #endif
100
+
101
+ // Use this macro, and not is_build_feature_available
102
+ #define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME)
103
+
104
+ // returns true iff allowlist contains item
105
+ // allowlist_contains("a;bc;d", "bc") == true
106
+ constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item) {
107
+ //Choose a really big value for next so that if something goes wrong
108
+ //this code will blow up in a hopefully detectable way.
109
+ size_t next = std::numeric_limits<size_t>::max();
110
+ for (size_t cur = 0; cur <= allowlist.size(); cur = next) {
111
+ next = allowlist.find(';', cur);
112
+ if (next != std::string_view::npos) {
113
+ if (allowlist.substr(cur, next - cur) == item) {
114
+ return true;
115
+ }
116
+ next++;
117
+ } else {
118
+ if (allowlist.substr(cur) == item) {
119
+ return true;
120
+ }
121
+ break;
122
+ }
123
+ }
124
+ return false;
125
+ }
126
+
127
+ // Returns true iff the given op name is on the allowlist
128
+ // and should be registered
129
+ constexpr bool op_allowlist_check(std::string_view op_name [[maybe_unused]]) {
130
+ assert(op_name.find("::") != std::string_view::npos);
131
+ // Use assert() instead of throw() due to a gcc bug. See:
132
+ // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function
133
+ // https://github.com/fmtlib/fmt/issues/682
134
+ assert(op_name.find('(') == std::string_view::npos);
135
+ #if !defined(TORCH_OPERATOR_WHITELIST)
136
+ // If the TORCH_OPERATOR_WHITELIST parameter is not defined,
137
+ // all ops are to be registered
138
+ return true;
139
+ #else
140
+ return allowlist_contains(
141
+ C10_STRINGIZE(TORCH_OPERATOR_WHITELIST),
142
+ // This function is majorly used for mobile selective build with
143
+ // root operators, where the overload is included in the allowlist.
144
+ op_name);
145
+ // // Strip overload name (as allowlist doesn't contain overloads)
146
+ // // Another function based on this may be added when there's usage
147
+ // // on op names without overload.
148
+ // OperatorNameView::parse(op_name).name);
149
+ #endif
150
+ }
151
+
152
+ // Returns true iff the given schema string is on the allowlist
153
+ // and should be registered
154
+ constexpr bool schema_allowlist_check(std::string_view schema) {
155
+ #if defined(TORCH_FORCE_SCHEMA_REGISTRATION)
156
+ return true;
157
+ #else
158
+ return op_allowlist_check(schema.substr(0, schema.find('(')));
159
+ #endif
160
+ }
161
+
162
+ // Returns true iff the given custom class name is on the allowlist
163
+ // and should be registered
164
+ constexpr bool custom_class_allowlist_check(std::string_view custom_class_name [[maybe_unused]]) {
165
+ #if !defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
166
+ // If the TORCH_CUSTOM_CLASS_ALLOWLIST parameter is not defined,
167
+ // all custom classes are to be registered
168
+ return true;
169
+ #else
170
+ return allowlist_contains(
171
+ C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST),
172
+ custom_class_name);
173
+ #endif
174
+ }
175
+
176
+ // schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
177
+ // Add this API to pass arbitrary allowlist.
178
+ constexpr bool op_allowlist_contains_name_in_schema(std::string_view allowlist, std::string_view schema) {
179
+ return allowlist_contains(allowlist, schema.substr(0, schema.find('(')));
180
+ }
181
+
182
+ } // namespace c10::impl
183
+
184
+ #else
185
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
186
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_registration.h ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ /**
5
+ * Include this file if you want to register operators. It includes all
6
+ * functionality needed to do so for you.
7
+ */
8
+
9
+ #include <c10/core/DispatchKey.h>
10
+ #include <c10/core/DispatchKeySet.h>
11
+ #include <c10/core/CompileTimeFunctionPointer.h>
12
+ #include <ATen/core/boxing/KernelFunction.h>
13
+ #include <ATen/core/dispatch/CppSignature.h>
14
+ #include <ATen/core/dispatch/RegistrationHandleRAII.h>
15
+ #include <ATen/core/op_registration/infer_schema.h>
16
+ #if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
17
+ #include <torch/csrc/jit/frontend/function_schema_parser.h>
18
+ #endif
19
+ #include <ATen/core/ATenOpList.h>
20
+
21
+ namespace c10 {
22
+
23
+ namespace detail {
24
+ // The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
25
+ // We do this because every argument in a function schema is expected to be convertible
26
+ // to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
27
+ // See Note [Plumbing Keys Through The Dispatcher]
28
+ template<class KernelFunctor>
29
+ std::unique_ptr<FunctionSchema> inferFunctionSchemaFromFunctor() {
30
+ using func_type = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::func_type;
31
+ return std::make_unique<FunctionSchema>(inferFunctionSchemaFlattenedReturns<func_type>());
32
+ }
33
+ }
34
+
35
+ /**
36
+ * An instance of this class handles the registration for one or more operators.
37
+ * Make sure you keep the RegisterOperators instance around since it will
38
+ * deregister the operator it's responsible for in its destructor.
39
+ *
40
+ * Example:
41
+ *
42
+ * > namespace {
43
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
44
+ * > public:
45
+ * > Tensor operator()(Tensor a, Tensor b) {...}
46
+ * > };
47
+ * > }
48
+ * >
49
+ * > static auto registry = c10::RegisterOperators()
50
+ * > .op(c10::RegisterOperators::options()
51
+ * > .schema("my_op")
52
+ * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
53
+ */
54
+ class TORCH_API RegisterOperators final {
55
+ public:
56
+ RegisterOperators() = default;
57
+ ~RegisterOperators() = default;
58
+
59
+ RegisterOperators(const RegisterOperators&) = delete;
60
+ RegisterOperators& operator=(const RegisterOperators&) = delete;
61
+ RegisterOperators(RegisterOperators&&) noexcept = default;
62
+ RegisterOperators& operator=(RegisterOperators&&) noexcept = default;
63
+
64
+ class TORCH_API Options final {
65
+ public:
66
+ Options(const Options&) = delete;
67
+ Options(Options&&) noexcept = delete;
68
+ Options& operator=(const Options&) = delete;
69
+ Options& operator=(Options&&) noexcept = delete;
70
+
71
+ // internal-only for registering stack based kernels
72
+ template<KernelFunction::BoxedKernelFunction* kernel_func>
73
+ Options&& kernel(DispatchKey dispatch_key) && {
74
+ return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
75
+ }
76
+
77
+ // internal-only for registering stack based catch-all kernels
78
+ template<KernelFunction::BoxedKernelFunction* kernel_func>
79
+ Options&& catchAllKernel() && {
80
+ return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
81
+ }
82
+
83
+ // internal only for registering caffe2 ops
84
+ Options&& schema(FunctionSchema&& schema) {
85
+ TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration.");
86
+ schemaOrName_ = FunctionSchema(std::move(schema));
87
+ return std::move(*this);
88
+ }
89
+
90
+ /**
91
+ * Use this to specify the schema for an operator. You can also specify
92
+ * the operator name only to have the function signature part of the
93
+ * schema be inferred from the kernel function.
94
+ *
95
+ * Example:
96
+ *
97
+ * > // Infer function signature from my_kernel_cpu
98
+ * > static auto registry = c10::RegisterOperators()
99
+ * > .op(c10::RegisterOperators::options()
100
+ * > .schema("my_op")
101
+ * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
102
+ * >
103
+ * >
104
+ * > // Explicitly specify full schema
105
+ * > static auto registry = c10::RegisterOperators()
106
+ * > .op(c10::RegisterOperators::options()
107
+ * > .schema("my_op(Tensor a) -> Tensor")
108
+ * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
109
+ */
110
+ Options&& schema(const std::string& schemaOrName) {
111
+ TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration.");
112
+
113
+ #if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD)
114
+ throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
115
+ #else
116
+ schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName);
117
+ #endif
118
+
119
+ return std::move(*this);
120
+ }
121
+
122
+ /**
123
+ * Use this to register an operator whose kernel is implemented as a functor.
124
+ * The kernel is only called for inputs matching the given dispatch key.
125
+ * You can register multiple kernels for different dispatch keys.
126
+ *
127
+ * Example:
128
+ *
129
+ * > namespace {
130
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
131
+ * > public:
132
+ * > Tensor operator()(Tensor a, Tensor b) {...}
133
+ * > };
134
+ * > }
135
+ * >
136
+ * > static auto registry = c10::RegisterOperators()
137
+ * > .op(c10::RegisterOperators::options()
138
+ * > .schema("my_op")
139
+ * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
140
+ *
141
+ * The functor constructor can take arguments to configure the kernel.
142
+ * The arguments are defined in the kernel registration.
143
+ * Example:
144
+ *
145
+ * > namespace {
146
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
147
+ * > public:
148
+ * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
149
+ * > : ... {...}
150
+ * >
151
+ * > Tensor operator()(Tensor a, Tensor b) {...}
152
+ * > };
153
+ * > }
154
+ * >
155
+ * > static auto registry = c10::RegisterOperators()
156
+ * > .op(c10::RegisterOperators::options()
157
+ * > .schema("my_op")
158
+ * > .kernel<my_kernel_cpu>(DispatchKey::CPU, "some_configuration", 3, true));
159
+ */
160
+ template<class KernelFunctor, class... ConstructorParameters>
161
+ // enable_if: only enable it if KernelFunctor is actually a functor
162
+ std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && {
163
+ static_assert(std::is_base_of_v<OperatorKernel, KernelFunctor>, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
164
+ static_assert(std::is_constructible_v<KernelFunctor, ConstructorParameters...>, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
165
+
166
+ return std::move(*this).kernel(
167
+ dispatch_key,
168
+ KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
169
+ impl::CppSignature::make<KernelFunctor>(),
170
+ detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
171
+ );
172
+ }
173
+
174
+ /**
175
+ * Use this to register an operator whose kernel is implemented as a functor.
176
+ * The kernel is a catch-all kernel, meaning it's called independent from
177
+ * the input. Dispatch is disabled for this operator.
178
+ *
179
+ * Example:
180
+ *
181
+ * > namespace {
182
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
183
+ * > public:
184
+ * > Tensor operator()(Tensor a, Tensor b) {...}
185
+ * > };
186
+ * > }
187
+ * >
188
+ * > static auto registry = c10::RegisterOperators()
189
+ * > .op(c10::RegisterOperators::options()
190
+ * > .schema("my_op")
191
+ * > .catchAllKernel<my_kernel_cpu>());
192
+ *
193
+ * The functor constructor can take arguments to configure the kernel.
194
+ * The arguments are defined in the kernel registration.
195
+ * Example:
196
+ *
197
+ * > namespace {
198
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
199
+ * > public:
200
+ * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
201
+ * > : ... {...}
202
+ * >
203
+ * > Tensor operator()(Tensor a, Tensor b) {...}
204
+ * > };
205
+ * > }
206
+ * >
207
+ * > static auto registry = c10::RegisterOperators()
208
+ * > .op(c10::RegisterOperators::options()
209
+ * > .schema("my_op")
210
+ * > .catchAllKernel<my_kernel_cpu>("some_configuration", 3, true));
211
+ */
212
+ template<class KernelFunctor, class... ConstructorParameters>
213
+ // enable_if: only enable it if KernelFunctor is actually a functor
214
+ std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && {
215
+ static_assert(std::is_base_of_v<OperatorKernel, KernelFunctor>, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
216
+ static_assert(std::is_constructible_v<KernelFunctor, ConstructorParameters...>, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
217
+
218
+ return std::move(*this).kernel(
219
+ std::nullopt,
220
+ KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
221
+ impl::CppSignature::make<KernelFunctor>(),
222
+ detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
223
+ );
224
+ }
225
+
226
+ /**
227
+ * Use this to register an operator whose kernel is implemented by a function.
228
+ * The kernel is only called for inputs matching the given dispatch key.
229
+ * You can register multiple kernels for different dispatch keys.
230
+ *
231
+ * Example:
232
+ *
233
+ * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
234
+ * >
235
+ * > static auto registry = c10::RegisterOperators()
236
+ * > .op(c10::RegisterOperators::options()
237
+ * > .schema("my_op")
238
+ * > .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(DispatchKey::CPU));
239
+ */
240
+ template<class FuncType, FuncType* kernel_func>
241
+ // enable_if: only enable it if FuncType is actually a function
242
+ std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key) && {
243
+ static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
244
+ static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
245
+
246
+ return std::move(*this).kernel(
247
+ dispatch_key,
248
+ KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
249
+ impl::CppSignature::make<FuncType>(),
250
+ // TODO Do schema inference without relying on WrapFunctionIntoFunctor
251
+ detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
252
+ );
253
+ }
254
+
255
+ /**
256
+ * Use this to register an operator whose kernel is implemented by a function.
257
+ * The kernel is a catch-all kernel, meaning it's called independent from
258
+ * the input. Dispatch is disabled for this operator.
259
+ *
260
+ * Example:
261
+ *
262
+ * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
263
+ * >
264
+ * > static auto registry = c10::RegisterOperators()
265
+ * > .op(c10::RegisterOperators::options()
266
+ * > .schema("my_op")
267
+ * > .catchAllKernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
268
+ */
269
+ template<class FuncType, FuncType* kernel_func>
270
+ // enable_if: only enable it if FuncType is actually a function
271
+ std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel() && {
272
+ static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
273
+ static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
274
+
275
+ return std::move(*this).kernel(
276
+ std::nullopt,
277
+ KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
278
+ impl::CppSignature::make<FuncType>(),
279
+ // TODO Do schema inference without relying on WrapFunctionIntoFunctor
280
+ detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
281
+ );
282
+ }
283
+
284
+ template<class FuncType>
285
+ // enable_if: only enable it if FuncType is actually a function
286
+ std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && {
287
+ static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
288
+ TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
289
+
290
+ return std::move(*this).kernel(
291
+ dispatch_key,
292
+ KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
293
+ impl::CppSignature::make<FuncType>(),
294
+ // TODO Do schema inference without relying on WrapFunctionIntoFunctor
295
+ detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
296
+ );
297
+ }
298
+
299
+ template<class FuncType>
300
+ // enable_if: only enable it if FuncType is actually a function
301
+ std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel(FuncType* kernel_func) && {
302
+ static_assert(!std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
303
+ TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
304
+
305
+ return std::move(*this).kernel(
306
+ std::nullopt,
307
+ KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
308
+ impl::CppSignature::make<FuncType>(),
309
+ // TODO Do schema inference without relying on WrapFunctionIntoFunctor
310
+ detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
311
+ );
312
+ }
313
+
314
+ /**
315
+ * Use this to register an operator whose kernel is implemented as a lambda.
316
+ * The kernel is only called for inputs matching the given dispatch key.
317
+ * You can register multiple kernels for different dispatch keys.
318
+ *
319
+ * The lambda must be stateless, i.e. not have a capture. If your kernel
320
+ * needs to store some configuration parameters, write the kernel as a
321
+ * functor instead.
322
+ *
323
+ * Example:
324
+ *
325
+ * > static auto registry = c10::RegisterOperators()
326
+ * > .op(c10::RegisterOperators::options()
327
+ * > .schema("my_op")
328
+ * > .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...}));
329
+ */
330
+ template<class Lambda>
331
+ // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
332
+ std::enable_if_t<
333
+ guts::is_functor<std::decay_t<Lambda>>::value
334
+ && !std::is_same_v<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>,
335
+ Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && {
336
+ static_assert(!std::is_base_of_v<OperatorKernel, std::decay_t<Lambda>>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
337
+
338
+ // We don't support stateful lambdas (i.e. lambdas with a capture), because their
339
+ // behavior would be nonobvious. A functor kernel with cache gets a new instance of
340
+ // its cache each time the kernel is looked up from the dispatch table.
341
+ // A lambda with a capture would be global and share its capture between all kernel lookups.
342
+ // So, instead of making users having to think about it (including the thread-safety
343
+ // issues this causes), let's just forbid stateful lambdas altogether.
344
+ static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
345
+
346
+ return std::move(*this).kernel(
347
+ dispatch_key,
348
+ KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(functor)),
349
+ impl::CppSignature::make<Lambda>(),
350
+ // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
351
+ detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
352
+ );
353
+ }
354
+
355
+ /**
356
+ * Use this to register an operator whose kernel is implemented as a lambda.
357
+ * The kernel is a catch-all kernel, meaning it's called independent from
358
+ * the input. Dispatch is disabled for this operator.
359
+ *
360
+ * The lambda must be stateless, i.e. not have a capture. If your kernel
361
+ * needs to store some configuration parameters, write the kernel as a
362
+ * functor instead.
363
+ *
364
+ * Example:
365
+ *
366
+ * > static auto registry = c10::RegisterOperators()
367
+ * > .op(c10::RegisterOperators::options()
368
+ * > .schema("my_op")
369
+ * > .catchAllKernel([] (Tensor a) -> Tensor {...}));
370
+ */
371
+ template<class Lambda>
372
+ // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
373
+ std::enable_if_t<
374
+ guts::is_functor<std::decay_t<Lambda>>::value
375
+ && !std::is_same_v<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>,
376
+ Options&&> catchAllKernel(Lambda&& lambda) && {
377
+ static_assert(!std::is_base_of_v<OperatorKernel, std::decay_t<Lambda>>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
378
+
379
+ // We don't support stateful lambdas (i.e. lambdas with a capture), because their
380
+ // behavior would be nonobvious.
381
+ // A lambda with a capture would be global and share its capture between all kernel lookups.
382
+ // This would be a likely source for unexpected race conditions, so we forbid it.
383
+ // If a kernel really needs global state, they can just have regular global state
384
+ // in their .cpp file next to the kernel lambda.
385
+ static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
386
+
387
+ return std::move(*this).kernel(
388
+ std::nullopt,
389
+ KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(lambda)),
390
+ impl::CppSignature::make<Lambda>(),
391
+ // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
392
+ detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
393
+ );
394
+ }
395
+
396
+ Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && {
397
+ TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration.");
398
+ aliasAnalysisKind_ = aliasAnalysisKind;
399
+ return std::move(*this);
400
+ }
401
+
402
+ private:
403
+ Options&& kernel(std::optional<DispatchKey> dispatch_key, KernelFunction&& func, std::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
404
+ KernelRegistrationConfig config;
405
+ config.dispatch_key = dispatch_key;
406
+ config.func = std::move(func);
407
+ config.cpp_signature = cpp_signature;
408
+ config.inferred_function_schema = std::move(inferred_function_schema);
409
+ kernels.push_back(std::move(config));
410
+ return std::move(*this);
411
+ }
412
+
413
+ Options()
414
+ : schemaOrName_(std::nullopt)
415
+ , aliasAnalysisKind_(std::nullopt)
416
+ {}
417
+
418
+ // KernelRegistrationConfig accumulates all information from the config
419
+ // parameters passed to a RegisterOperators::op() call into one object.
420
+ struct KernelRegistrationConfig final {
421
+ KernelRegistrationConfig()
422
+ : dispatch_key(std::nullopt)
423
+ , cpp_signature(std::nullopt)
424
+ , inferred_function_schema(nullptr)
425
+ {}
426
+
427
+ std::optional<DispatchKey> dispatch_key;
428
+ KernelFunction func;
429
+ std::optional<impl::CppSignature> cpp_signature;
430
+ std::unique_ptr<FunctionSchema> inferred_function_schema;
431
+ };
432
+
433
+ std::optional<std::variant<OperatorName, FunctionSchema>> schemaOrName_;
434
+
435
+ std::vector<KernelRegistrationConfig> kernels;
436
+ std::optional<AliasAnalysisKind> aliasAnalysisKind_;
437
+ friend class RegisterOperators;
438
+ friend class Library;
439
+ };
440
+
441
+ /**
442
+ * Call this to get an instance of registration options, which
443
+ * can be passed to a call to RegisterOperators::op() to specify
444
+ * these options for the operator registration.
445
+ * See class doc comment for examples.
446
+ */
447
+ static Options options() {
448
+ return {};
449
+ }
450
+
451
+ /**
452
+ * Call this to register an operator. See class doc comment for examples.
453
+ */
454
+ RegisterOperators&& op(Options&& options) && {
455
+ checkSchemaAndRegisterOp_(std::move(options));
456
+ return std::move(*this);
457
+ }
458
+
459
+ // Regular mutator version of the && version above
460
+ RegisterOperators& op(Options&& options) & {
461
+ checkSchemaAndRegisterOp_(std::move(options));
462
+ return *this;
463
+ }
464
+
465
+ /**
466
+ * This is a shorthand for RegisterOperators::op(Options) where you can
467
+ * specify the operator schema outside of the options parameter.
468
+ * See class doc comment for examples.
469
+ */
470
+ RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && {
471
+ return std::move(*this).op(std::move(options).schema(schemaOrName));
472
+ }
473
+
474
+ // internal only for registering caffe2 ops
475
+ RegisterOperators&& op(FunctionSchema schema, Options&& options) && {
476
+ return std::move(*this).op(std::move(options).schema(std::move(schema)));
477
+ }
478
+
479
+ template<class FuncType>
480
+ explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options())
481
+ : RegisterOperators() {
482
+ std::move(*this).op(schemaOrName, std::forward<FuncType>(func), std::move(options));
483
+ }
484
+
485
+ /**
486
+ * This API registers an operator based on a kernel function pointer.
487
+ *
488
+ * Given a kernel
489
+ *
490
+ * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
491
+ *
492
+ * This API looks like:
493
+ *
494
+ * > static auto registry = c10::RegisterOperators()
495
+ * > .op("my_op", &my_kernel_cpu);
496
+ *
497
+ * If your kernel is small and the overhead of calling it matters,
498
+ * then this API might be the wrong choice since the following API
499
+ * has a slightly lower overhead for calling into the kernel:
500
+ *
501
+ * > static auto registry = c10::RegisterOperators()
502
+ * > .op("my_op", c10::RegisterOperators::options()
503
+ * > .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
504
+ *
505
+ * Or, alternatively, write your kernel as a functor:
506
+ *
507
+ * > namespace {
508
+ * > class my_kernel_cpu final : public c10::OperatorKernel {
509
+ * > public:
510
+ * > Tensor operator()(Tensor a, Tensor b) {...}
511
+ * > };
512
+ * > }
513
+ * >
514
+ * > static auto registry = c10::RegisterOperators()
515
+ * > .op("my_op", c10::RegisterOperators::options()
516
+ * > .kernel<my_kernel_cpu>());
517
+ */
518
+ template<class FuncType>
519
+ // enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction.
520
+ std::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same_v<FuncType, KernelFunction::BoxedKernelFunction>, RegisterOperators&&>
521
+ op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && {
522
+ constexpr bool AllowLegacyTypes = true;
523
+ return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
524
+ std::nullopt,
525
+ KernelFunction::makeFromUnboxedRuntimeFunction<AllowLegacyTypes>(func),
526
+ impl::CppSignature::make<FuncType>(),
527
+ // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
528
+ detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
529
+ ));
530
+ }
531
+
532
+ /**
533
+ * This API registers an operator based on a kernel lambda.
534
+ *
535
+ * This API looks like:
536
+ *
537
+ * > static auto registry = c10::RegisterOperators()
538
+ * > .op("my_op", [] (Tensor a, Tensor b) {...});
539
+ *
540
+ * This is equivalent to:
541
+ *
542
+ * > static auto registry = c10::RegisterOperators()
543
+ * > .op("my_op", c10::RegisterOperators::options()
544
+ * > .catchAllKernel([] (Tensor a, Tensor b) {...}));
545
+ *
546
+ */
547
+ template<class Lambda>
548
+ // enable_if: only enable it if Lambda is actually a stateless lambda
549
+ std::enable_if_t<guts::is_functor<Lambda>::value && guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
550
+ op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
551
+ static_assert(!std::is_base_of_v<OperatorKernel, Lambda>, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
552
+
553
+ constexpr bool AllowLegacyTypes = true;
554
+ return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
555
+ std::nullopt,
556
+ KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
557
+ impl::CppSignature::make<Lambda>(),
558
+ // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
559
+ detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
560
+ ));
561
+ }
562
+
563
+ template<class Lambda>
564
+ C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead.")
565
+ // enable_if: only enable it if Lambda is actually a functor but not a stateless lambda
566
+ std::enable_if_t<guts::is_functor<Lambda>::value && !guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
567
+ op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
568
+ static_assert(!std::is_base_of_v<OperatorKernel, Lambda>, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
569
+
570
+ constexpr bool AllowLegacyTypes = true;
571
+ return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
572
+ std::nullopt,
573
+ KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
574
+ impl::CppSignature::make<Lambda>(),
575
+ // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
576
+ detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
577
+ ));
578
+ }
579
+
580
+ private:
581
+ void checkSchemaAndRegisterOp_(Options&& config);
582
+
583
+ static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options);
584
+ void checkNoDuplicateKernels_(const Options& options);
585
+ void registerOp_(Options&& options);
586
+
587
+ std::vector<RegistrationHandleRAII> registrars_;
588
+ };
589
+
590
+ } // namespace c10
591
+
592
+ namespace torch {
593
+ // Old-style API
594
+ using RegisterOperators = c10::RegisterOperators;
595
+ }
596
+
597
+ #else
598
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
599
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/FlushDenormal.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /// Flush-To-Zero and Denormals-Are-Zero mode
3
+ ///
4
+ /// Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) are modes that bypass
5
+ /// IEEE 754 methods of dealing with denormal floating-point numbers on x86-64
6
+ /// and some x86 CPUs. They result in reduced precision for values near zero,
7
+ /// but increased performance.
8
+ ///
9
+ /// See https://software.intel.com/en-us/articles/x87-and-sse-floating-point-assists-in-ia-32-flush-to-zero-ftz-and-denormals-are-zero-daz
10
+
11
+ namespace at::cpu {
12
+
13
+ bool set_flush_denormal(bool on);
14
+
15
+ } // namespace at::cpu
16
+
17
+ #else
18
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
19
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/Utils.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <cstdint>
5
+
6
+ #include <c10/macros/Export.h>
7
+
8
+ namespace at::cpu {
9
+
10
+ TORCH_API bool is_avx2_supported();
11
+ TORCH_API bool is_avx512_supported();
12
+
13
+ // Detect if CPU support Vector Neural Network Instruction.
14
+ TORCH_API bool is_avx512_vnni_supported();
15
+
16
+ // Detect if CPU supports AVX512_BF16 ISA
17
+ TORCH_API bool is_avx512_bf16_supported();
18
+
19
+ // Detect if CPU support Advanced Matrix Extension.
20
+ TORCH_API bool is_amx_tile_supported();
21
+
22
+ // Detect if CPU support Advanced Matrix Extension for fp16.
23
+ TORCH_API bool is_amx_fp16_supported();
24
+
25
+ // Enable the system to use AMX instructions.
26
+ TORCH_API bool init_amx();
27
+
28
+ // Get the L1 cache size per core in Byte
29
+ TORCH_API uint32_t L1d_cache_size();
30
+
31
+ // Get the L2 cache size per core in Byte
32
+ TORCH_API uint32_t L2_cache_size();
33
+
34
+ } // namespace at::cpu
35
+
36
+ #else
37
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
38
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vml.h ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/Config.h>
5
+ #include <ATen/Parallel.h>
6
+ #include <ATen/OpMathType.h>
7
+ #include <ATen/cpu/vec/functional.h>
8
+ #include <ATen/cpu/vec/vec.h>
9
+ #include <c10/util/complex.h>
10
+
11
+ // This header implements various unary operations using a MKL VML style
12
+ // interface.
13
+
14
+ // It implements various functions with a simple interface
15
+ // For example it enables the user to call vsin(float* out, const float* in,
16
+ // size) This functions takes a pointer to a continuous output array of floats and
17
+ // a constant input array. It will then apply sin to each value in the input
18
+ // array and write the result into the output array. out and in may point to the
19
+ // same memory, i.e. this fully supports in-place operations. These functions
20
+ // also implement their own parallelization, so take precautions when calling
21
+ // these from threaded functions.
22
+
23
+ // When MKL is available it will call into MKL's VML library similar to NumPy
24
+ // If MKL is not available it will use SLEEF.
25
+
26
+ // This file might be compiled under AVX or AVX2 when called from e.g.
27
+ // UnaryOpsKernel.cpp
28
+
29
+ #include <algorithm>
30
+ #include <cstddef>
31
+ #include <cstdint>
32
+ #include <cstring>
33
+ #include <type_traits>
34
+
35
+ #if AT_MKL_ENABLED() && !defined(__APPLE__)
36
+ #include <mkl.h>
37
+ #endif
38
+
39
+
40
+ namespace at::vml {
41
+ inline namespace CPU_CAPABILITY {
42
+
43
+ using namespace vec;
44
+
45
+ template <typename scalar_t>
46
+ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
47
+ parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
48
+ map(
49
+ [](const Vectorized<scalar_t>& x) {
50
+ return Vectorized<scalar_t>((scalar_t)1) / x.sqrt();
51
+ },
52
+ out + begin,
53
+ in + begin,
54
+ end - begin);
55
+ });
56
+ }
57
+
58
+ // NB: We ignore numerical errors by convention and leave them to the user
59
+
60
+ #define IMPLEMENT_VML(op) \
61
+ template <typename scalar_t> \
62
+ inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
63
+ using vec_t = Vectorized<vec_scalar_t<scalar_t>>; \
64
+ vec::map([](vec_t x) { return x.op(); }, out, in, size); \
65
+ } \
66
+
67
+ IMPLEMENT_VML(abs)
68
+ IMPLEMENT_VML(acos)
69
+ IMPLEMENT_VML(asin)
70
+ IMPLEMENT_VML(atan)
71
+ IMPLEMENT_VML(atanh)
72
+ IMPLEMENT_VML(ceil)
73
+ IMPLEMENT_VML(cos)
74
+ // IMPLEMENT_VML(cosh)
75
+ IMPLEMENT_VML(erf)
76
+ IMPLEMENT_VML(erfc)
77
+ IMPLEMENT_VML(erfinv)
78
+ IMPLEMENT_VML(exp)
79
+ IMPLEMENT_VML(expm1)
80
+ IMPLEMENT_VML(floor)
81
+ IMPLEMENT_VML(i0)
82
+ IMPLEMENT_VML(i0e)
83
+ IMPLEMENT_VML(digamma)
84
+ IMPLEMENT_VML(reciprocal)
85
+ IMPLEMENT_VML(log)
86
+ IMPLEMENT_VML(log10)
87
+ IMPLEMENT_VML(log1p)
88
+ IMPLEMENT_VML(log2)
89
+ IMPLEMENT_VML(neg)
90
+ IMPLEMENT_VML(sin)
91
+ // IMPLEMENT_VML(sinh)
92
+ IMPLEMENT_VML(sqrt)
93
+ IMPLEMENT_VML(round)
94
+ IMPLEMENT_VML(rsqrt)
95
+ IMPLEMENT_VML(tan)
96
+ IMPLEMENT_VML(tanh)
97
+ IMPLEMENT_VML(trunc)
98
+ IMPLEMENT_VML(lgamma)
99
+
100
+
101
+ #if AT_MKL_ENABLED() && !defined(__APPLE__)
102
+
103
+ // NB: LP64 MKL is the most commonly used and thus we assume it here. That means
104
+ // we need to expect MKL_INT to be of type int, which implies int32_t or int64_t in most
105
+ // cases.
106
+ static_assert(
107
+ std::is_same_v<MKL_INT, int32_t> || std::is_same_v<MKL_INT, int64_t>,
108
+ "MKL_INT is assumed to be int32_t or int64_t");
109
+ #define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \
110
+ template <> \
111
+ inline void v##op(type * out, const type * in, int64_t size) { \
112
+ auto constexpr max_mkl_ind = std::numeric_limits<MKL_INT>::max(); \
113
+ if (size <= static_cast<int64_t>(max_mkl_ind)) { \
114
+ vm##mkltype##mklop( \
115
+ size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
116
+ } else { \
117
+ int64_t ind = 0; \
118
+ int64_t chunks = size / max_mkl_ind; \
119
+ int64_t rest = size % max_mkl_ind; \
120
+ for (; ind < chunks; ind++) { \
121
+ vm##mkltype##mklop( \
122
+ max_mkl_ind, \
123
+ in + ind * max_mkl_ind, \
124
+ out + ind * max_mkl_ind, \
125
+ VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
126
+ } \
127
+ vm##mkltype##mklop( \
128
+ rest, \
129
+ in + ind * max_mkl_ind, \
130
+ out + ind * max_mkl_ind, \
131
+ VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
132
+ } \
133
+ }
134
+
135
+ #define IMPLEMENT_VML_MKL(op, mklop) \
136
+ IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
137
+ IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
138
+
139
+ // NB: abs, cosh and sinh were temporarily disabled due to issues with Apple
140
+ // NB: expm1 is disabled because on some configs it produces expm1(nan)=-1
141
+ IMPLEMENT_VML_MKL(acos, Acos)
142
+ IMPLEMENT_VML_MKL(asin, Asin)
143
+ IMPLEMENT_VML_MKL(atan, Atan)
144
+ IMPLEMENT_VML_MKL(cos, Cos)
145
+ // IMPLEMENT_VML_MKL(cosh, Cosh)
146
+ IMPLEMENT_VML_MKL(erf, Erf)
147
+ IMPLEMENT_VML_MKL(erfc, Erfc)
148
+ IMPLEMENT_VML_MKL(erfinv, ErfInv)
149
+ IMPLEMENT_VML_MKL(exp, Exp)
150
+ // IMPLEMENT_VML_MKL(expm1, Expm1)
151
+ IMPLEMENT_VML_MKL(log, Ln)
152
+ IMPLEMENT_VML_MKL(log10, Log10)
153
+ IMPLEMENT_VML_MKL(sin, Sin)
154
+ // IMPLEMENT_VML_MKL(sinh, Sinh)
155
+ IMPLEMENT_VML_MKL(sqrt, Sqrt)
156
+ IMPLEMENT_VML_MKL(tan, Tan)
157
+ IMPLEMENT_VML_MKL(tanh, Tanh)
158
+ IMPLEMENT_VML_MKL(trunc, Trunc)
159
+
160
+ // Not vectorized in MKL version tested
161
+ // IMPLEMENT_VML_MKL(abs, Abs)
162
+ // IMPLEMENT_VML_MKL(log1p, Log1p)
163
+
164
+ #if INTEL_MKL_VERSION >= 20180406
165
+ IMPLEMENT_VML_MKL(log2, Log2)
166
+ #endif
167
+
168
+ #endif
169
+
170
+ } // namespace
171
+ } // namespace at::vml
172
+
173
+ #else
174
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
175
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/BLASConstants.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/TensorBase.h>
5
+
6
+ namespace at::cuda::detail {
7
+
8
+ float *get_cublas_device_one();
9
+ float *get_cublas_device_zero();
10
+ float *get_user_alpha_ptr();
11
+
12
+ } // namespace at::cuda::detail
13
+
14
+ #else
15
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
16
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/detail/CUDAHooksInterface.h>
5
+
6
+ #include <ATen/Generator.h>
7
+
8
+ // TODO: No need to have this whole header, we can just put it all in
9
+ // the cpp file
10
+
11
+ namespace at::cuda::detail {
12
+
13
+ // Set the callback to initialize Magma, which is set by
14
+ // torch_cuda_cu. This indirection is required so magma_init is called
15
+ // in the same library where Magma will be used.
16
+ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
17
+
18
+
19
+ // The real implementation of CUDAHooksInterface
20
+ struct CUDAHooks : public at::CUDAHooksInterface {
21
+ CUDAHooks(at::CUDAHooksArgs /*unused*/) {}
22
+ void init() const override;
23
+ Device getDeviceFromPtr(void* data) const override;
24
+ bool isPinnedPtr(const void* data) const override;
25
+ const Generator& getDefaultGenerator(
26
+ DeviceIndex device_index = -1) const override;
27
+ Generator getNewGenerator(
28
+ DeviceIndex device_index = -1) const override;
29
+ bool hasCUDA() const override;
30
+ bool hasMAGMA() const override;
31
+ bool hasCuDNN() const override;
32
+ bool hasCuSOLVER() const override;
33
+ bool hasCuBLASLt() const override;
34
+ bool hasROCM() const override;
35
+ bool hasCKSDPA() const override;
36
+ bool hasCKGEMM() const override;
37
+ const at::cuda::NVRTC& nvrtc() const override;
38
+ DeviceIndex current_device() const override;
39
+ bool isBuilt() const override {return true;}
40
+ bool isAvailable() const override {return hasCUDA();}
41
+ bool hasPrimaryContext(DeviceIndex device_index) const override;
42
+ Allocator* getCUDADeviceAllocator() const override;
43
+ Allocator* getPinnedMemoryAllocator() const override;
44
+ bool compiledWithCuDNN() const override;
45
+ bool compiledWithMIOpen() const override;
46
+ bool supportsDilatedConvolutionWithCuDNN() const override;
47
+ bool supportsDepthwiseConvolutionWithCuDNN() const override;
48
+ bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
49
+ bool supportsBFloat16RNNWithCuDNN() const override;
50
+ bool hasCUDART() const override;
51
+ long versionCUDART() const override;
52
+ long versionCuDNN() const override;
53
+ long versionRuntimeCuDNN() const override;
54
+ long versionCuDNNFrontend() const override;
55
+ long versionMIOpen() const override;
56
+ std::string showConfig() const override;
57
+ double batchnormMinEpsilonCuDNN() const override;
58
+ int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
59
+ void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
60
+ int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
61
+ void cuFFTClearPlanCache(DeviceIndex device_index) const override;
62
+ int getNumGPUs() const override;
63
+ DeviceIndex deviceCount() const override;
64
+ DeviceIndex getCurrentDevice() const override;
65
+
66
+ #ifdef USE_ROCM
67
+ bool isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index = -1) const override;
68
+ #endif
69
+ void deviceSynchronize(DeviceIndex device_index) const override;
70
+ };
71
+
72
+ } // at::cuda::detail
73
+
74
+ #else
75
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
76
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
3
+ // These handles are tied to device, and these libraries requires/recommends not to
4
+ // share handles across host threads.
5
+ //
6
+ // These libraries recommend using one handle per host thread. We may not want to do
7
+ // this because threads are relatively light-weight, but creating and destroying
8
+ // handles is expensive (destroying the handle causes synchronizations). DataParallel,
9
+ // for example, creates new threads for each forward pass.
10
+ //
11
+ // This file implements a handle pool mechanism. The handle pool returns handles on
12
+ // demand as threads request them. If all existing handles in the pool are in use,
13
+ // it creates a new one. As threads terminate, they release handles back into the pool.
14
+ // In this way, the handle pool never creates more handles than the high-water mark of
15
+ // active threads, so it's efficient with DataParallel.
16
+
17
+ #pragma once
18
+
19
+ #include <unordered_map>
20
+ #include <vector>
21
+ #include <utility>
22
+ #include <mutex>
23
+ #include <memory>
24
+
25
+ #include <c10/util/Exception.h>
26
+
27
+ namespace at::cuda { namespace {
28
+
29
+ template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
30
+ struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
31
+
32
+ struct Handle {
33
+ Handle_t handle;
34
+ Handle(bool create = false) : handle(nullptr)
35
+ {
36
+ if(create) Create(&handle);
37
+ }
38
+ // std::vector.emplace() and push_back() may route through temporaries and call
39
+ // copy/move constructors along the way. If this is the case, we don't want
40
+ // the destructors of temporaries to call cudnnDestroy on the handle.
41
+ // We can achieve safety (for the narrow case of stashing within std::vectors)
42
+ // by making Handle moveable but not copyable, and transferring handle ownership
43
+ // to the latest constructed object. This is not a substitute for full-blown
44
+ // reference counting, but reference counting may be overkill here.
45
+ // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
46
+ // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
47
+ Handle(const Handle& rhs) = delete;
48
+ // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
49
+ Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
50
+ // operator= takes argument by value
51
+ Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
52
+ ~Handle() {
53
+ if(handle) Destroy(handle);
54
+ }
55
+ };
56
+
57
+ std::mutex mutex;
58
+
59
+ // Handles are lazily created as different threads request them,
60
+ // but are never destroyed until the end of the process.
61
+ // The maximum number of handles this process will create for each device is equal
62
+ // to the high-water mark of the number of concurrently active threads that request
63
+ // handles for that device.
64
+ // When threads terminate, they release their handles back into the pool for reuse.
65
+ // Otherwise, new handles would be created every time new threads were spawned,
66
+ // resulting in poor performance for Python modules that repeatedly or frequently
67
+ // spawned new sets of threads (like DataParallel, which creates a new set of threads
68
+ // for each forward pass).
69
+ //
70
+ // To prevent potential deadlocks, we explicitly choose not to cap the number
71
+ // of handles that are created per device.
72
+ // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
73
+ // only 4 can make forward progress at any time. The other 4 will not release their
74
+ // handles until they exit, so the fifth cannot make progress until then. This is
75
+ // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
76
+ // intermediate point (ie, before any of them have exited). We have no way to anticipate
77
+ // or enforce that user threads will not attempt such intermediate synchronization.
78
+ // The only way to ensure safety is to avoid imposing a cap on the number of handles.
79
+ std::unordered_map<int, std::vector<Handle>> created_handles;
80
+ std::unordered_map<int, std::vector<Handle_t>> available_handles;
81
+
82
+ // PoolWindow lazily creates and caches the handles that a particular thread is using,
83
+ // so in the common case handle access doesn't incur either handle creation or a mutex lock.
84
+ class PoolWindow
85
+ {
86
+ public:
87
+ PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
88
+ ~PoolWindow(){ release(); }
89
+
90
+ Handle_t reserve(int device)
91
+ {
92
+ // If this thread already has a handle for this device, return it
93
+ if(my_handles.find(device) != my_handles.end())
94
+ return my_handles[device];
95
+
96
+ // otherwise, either grab a handle from the pool if one is available,
97
+ // or if not, create a new one.
98
+ auto parent = weak_parent.lock();
99
+ TORCH_CHECK(parent, "Cannot create handle during program termination");
100
+ std::lock_guard<std::mutex> guard(parent->mutex);
101
+
102
+ if(parent->available_handles[device].size() > 0)
103
+ {
104
+ my_handles[device] = parent->available_handles[device].back();
105
+ parent->available_handles[device].pop_back();
106
+ }
107
+ else
108
+ {
109
+ // In local testing, I do observe that emplace_back sometimes routes through temporaries
110
+ // that incur move-constructor and destructor calls. See comments in Handle above.
111
+ parent->created_handles[device].emplace_back(true /*create*/);
112
+ my_handles[device] = parent->created_handles[device].back().handle;
113
+ }
114
+
115
+ return my_handles[device];
116
+ }
117
+
118
+ private:
119
+ // Stores the per-device handles currently owned by this thread
120
+ std::unordered_map<int, Handle_t> my_handles;
121
+
122
+ std::weak_ptr<DeviceThreadHandlePool> weak_parent;
123
+
124
+ // Called by the destructor. Releases this thread's handles back into the pool.
125
+ void release() {
126
+ if(!my_handles.empty()) {
127
+ auto parent = weak_parent.lock();
128
+ if (!parent) {
129
+ // If this thread exits after atexit handlers have completed, the
130
+ // cuda context itself may be invalid, so we must leak the handles.
131
+ return;
132
+ }
133
+
134
+ std::lock_guard<std::mutex> guard(parent->mutex);
135
+ for(auto d_h : my_handles)
136
+ parent->available_handles[d_h.first].push_back(d_h.second);
137
+ }
138
+ }
139
+ };
140
+
141
+ // Warning:
142
+ // If you want to change this function, be aware that this function will be called
143
+ // by multiple threads and there is no mutex guarding the call of this function, so
144
+ // make sure your implementation is thread-safe.
145
+ PoolWindow *newPoolWindow() {
146
+ // The returned pointer will be owned by a thread local variable
147
+ // so that different threads does not share the same PoolWindow.
148
+ return new PoolWindow(this->shared_from_this());
149
+ }
150
+ };
151
+
152
+ }} // namespace at::cuda::detail::<anonymous>
153
+
154
+ #else
155
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
156
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/TensorBase.h>
5
+ #include <ATen/cuda/detail/TensorInfo.cuh>
6
+ #include <ATen/native/CanUse32BitIndexMath.h>
7
+
8
+ namespace at::cuda::detail {
9
+
10
+ TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
11
+ using at::native::canUse32BitIndexMath;
12
+
13
+ template <typename scalar, typename IndexType>
14
+ TensorInfo<scalar, IndexType>
15
+ getTensorInfo(const at::TensorBase &t) {
16
+ IndexType sz[MAX_TENSORINFO_DIMS];
17
+ IndexType st[MAX_TENSORINFO_DIMS];
18
+
19
+ int dims = t.dim();
20
+ for (int i = 0; i < dims; ++i) {
21
+ sz[i] = t.size(i);
22
+ st[i] = t.stride(i);
23
+ }
24
+
25
+ scalar* data_ptr = nullptr;
26
+
27
+ if constexpr (std::is_const_v<scalar>) {
28
+ data_ptr = t.const_data_ptr<scalar>();
29
+ } else {
30
+ data_ptr = t.mutable_data_ptr<scalar>();
31
+ }
32
+
33
+ return TensorInfo<scalar, IndexType>(
34
+ data_ptr, dims, sz, st);
35
+ }
36
+
37
+ } // namespace at::cuda::detail
38
+
39
+ #else
40
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
41
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <assert.h>
5
+ #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
6
+ #include <cuda_runtime.h>
7
+ #endif
8
+
9
+ namespace at::cuda::detail {
10
+
11
+ // A utility class to implement integer division by multiplication, given a fixed
12
+ // divisor.
13
+ //
14
+ // WARNING: The fast divider algorithm is only implemented for unsigned int;
15
+ // otherwise we default to plain integer division. For unsigned int,
16
+ // we further assume that the dividend is at most INT32_MAX. Thus,
17
+ // IntDivider must NOT be used for general integer division.
18
+ //
19
+ // This reduced range is enough for our purpose, and it allows us to
20
+ // slightly simplify the computation.
21
+ //
22
+ // (NOTE: Below, "2^k" denotes exponentiation, i.e., 1<<k.)
23
+ //
24
+ // For any N-bit unsigned integer d (> 0), we can find a "magic number" m (2^N
25
+ // <= m < 2^(N+1)) and shift s such that:
26
+ //
27
+ // \floor(n / d) = \floor((m * n) / 2^(N+s)).
28
+ //
29
+ // Given such m and s, the integer division can be then implemented as:
30
+ //
31
+ // let m' = m - 2^N // 0 <= m' < 2^N
32
+ //
33
+ // fast_integer_division(n):
34
+ // // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
35
+ // // integer. Then take the higher N bits.
36
+ // t = (m' * n) >> N
37
+ //
38
+ // // Here we use the fact that n is less than 2^(N-1): otherwise the value
39
+ // // of (t + n) may not fit in an N-bit integer.
40
+ // return (t + n) >> s
41
+ //
42
+ // Finding such a magic number is surprisingly easy:
43
+ //
44
+ // s = \ceil(\log_2 d)
45
+ // m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic.
46
+ //
47
+ // See also:
48
+ // - Division by Invariant Integers Using Multiplication,
49
+ // Torbjörn Granlund and Peter L. Montgomery, 1994.
50
+ //
51
+ // - http://www.hackersdelight.org/magic.htm
52
+ //
53
+ // - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
54
+
55
+ // Result of div/mod operation stored together.
56
+ template <typename Value>
57
+ struct DivMod {
58
+ Value div, mod;
59
+
60
+ C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
61
+ };
62
+
63
+ // Base case: we only have an implementation for uint32_t for now. For
64
+ // everything else, we use plain division.
65
+ template <typename Value>
66
+ struct IntDivider {
67
+ IntDivider() = default;
68
+ IntDivider(Value d) : divisor(d) { }
69
+
70
+ C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
71
+ C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
72
+ C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
73
+ return DivMod<Value>(n / divisor, n % divisor);
74
+ }
75
+
76
+ Value divisor;
77
+ };
78
+
79
+ // Implement fast integer division.
80
+ template <>
81
+ struct IntDivider<unsigned int> {
82
+ static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");
83
+
84
+ IntDivider() = default;
85
+
86
+ IntDivider(unsigned int d) : divisor(d) {
87
+ assert(divisor >= 1 && divisor <= INT32_MAX);
88
+
89
+ // TODO: gcc/clang has __builtin_clz() but it's not portable.
90
+ for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;
91
+
92
+ uint64_t one = 1;
93
+ uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
94
+ m1 = magic;
95
+ assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits.
96
+ }
97
+
98
+ C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
99
+ #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
100
+ // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
101
+ // 'm1'.
102
+ unsigned int t = __umulhi(n, m1);
103
+ return (t + n) >> shift;
104
+ #else
105
+ // Using uint64_t so that the addition does not overflow.
106
+ uint64_t t = ((uint64_t) n * m1) >> 32;
107
+ return (t + n) >> shift;
108
+ #endif
109
+ }
110
+
111
+ C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
112
+ return n - div(n) * divisor;
113
+ }
114
+
115
+ C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
116
+ unsigned int q = div(n);
117
+ return DivMod<unsigned int>(q, n - q * divisor);
118
+ }
119
+
120
+ unsigned int divisor; // d above.
121
+ unsigned int m1; // Magic number: m' above.
122
+ unsigned int shift; // Shift amounts.
123
+ };
124
+
125
+ } // namespace at::cuda::detail
126
+
127
+ #else
128
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
129
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <limits>
5
+ #include <c10/util/Exception.h>
6
+
7
+ namespace at::cuda::detail {
8
+
9
+ // CUDA: grid stride looping
10
+ //
11
+ // int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
12
+ // If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
13
+ // iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
14
+ // greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
15
+ // further iterations and the overflowed value in i=_i_n_d_e_x is not used.
16
+ #define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
17
+ int64_t _i_n_d_e_x = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; \
18
+ for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
19
+
20
+ #define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
21
+
22
+
23
+ // Use 1024 threads per block, which requires cuda sm_2x or above
24
+ constexpr int CUDA_NUM_THREADS = 1024;
25
+
26
+ // CUDA: number of blocks for threads.
27
+ inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
28
+ TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
29
+ constexpr int64_t max_int = std::numeric_limits<int>::max();
30
+
31
+ // Round up division for positive number that cannot cause integer overflow
32
+ auto block_num = (N - 1) / max_threads_per_block + 1;
33
+ TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
34
+
35
+ return static_cast<int>(block_num);
36
+ }
37
+
38
+ } // namespace at::cuda::detail
39
+
40
+ #else
41
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
42
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/detail/CUDAHooksInterface.h>
4
+ namespace at::cuda {
5
+ // Forward-declares at::cuda::NVRTC
6
+ struct NVRTC;
7
+
8
+ namespace detail {
9
+ extern NVRTC lazyNVRTC;
10
+ } // namespace detail
11
+
12
+ } // namespace at::cuda
13
+
14
+ #else
15
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
16
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <array>
5
+ #include <cstdint>
6
+ #include <type_traits>
7
+ #include <c10/macros/Macros.h>
8
+ #include <ATen/native/TensorIterator.h>
9
+ #include <ATen/cuda/detail/IntegerDivider.cuh>
10
+
11
+ // If element_sizes is nullptr, then the strides will be in bytes, otherwise
12
+ // the strides will be in # of elements.
13
+ // Operands that share the same shape, but may have different strides.
14
+ // OffsetCalculator iterates the tensor in a column-major order
15
+
16
+ #if defined(USE_ROCM)
17
+ constexpr int MAX_DIMS = 16;
18
+ #else
19
+ constexpr int MAX_DIMS = 25;
20
+ #endif
21
+
22
+ template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
23
+ struct OffsetCalculator {
24
+ // We allow having negative strides to implement some operations like torch.flip
25
+ using stride_t = std::conditional_t<signed_strides,
26
+ std::make_signed_t<index_t>,
27
+ index_t>;
28
+ // The offset for each argument. Wrapper around fixed-size array.
29
+ // On CUDA, zero sized array is not allowed, so when we are handling nullary
30
+ // operators, we need to create a size 1 offset to avoid compiler failure.
31
+ // This size 1 offset is just a placeholder, and we will not use it.
32
+ using offset_type = std::array<stride_t, std::max<int>(NARGS, 1)>;
33
+
34
+ // if element_sizes is nullptr, then the strides will be in bytes, otherwise
35
+ // the strides will be in # of elements.
36
+ OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
37
+ TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
38
+ for (int i=0; i < dims; i++){
39
+ sizes_[i] = at::cuda::detail::IntDivider<index_t>(sizes[i]);
40
+ for (int arg = 0; arg < NARGS; arg++) {
41
+ int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
42
+ strides_[i][arg] = strides[arg][i] / element_size;
43
+ }
44
+ }
45
+ }
46
+
47
+ C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
48
+ offset_type offsets;
49
+
50
+ #if defined(USE_ROCM)
51
+ if ((dims > 0) && (dims <= 2)) {
52
+ auto divmod = sizes_[0].divmod(linear_idx);
53
+ #pragma unroll
54
+ for (int arg = 0; arg < NARGS; arg++)
55
+ offsets[arg] = divmod.mod * strides_[0][arg];
56
+ if (dims >= 2) {
57
+ divmod = sizes_[1].divmod(divmod.div);
58
+ #pragma unroll
59
+ for (int arg = 0; arg < NARGS; arg++)
60
+ offsets[arg] += divmod.mod * strides_[1][arg];
61
+ }
62
+ // [...]
63
+ return offsets;
64
+ }
65
+ #endif
66
+
67
+ #pragma unroll
68
+ for (int arg = 0; arg < NARGS; arg++) {
69
+ offsets[arg] = 0;
70
+ }
71
+
72
+ #pragma unroll
73
+ for (int dim = 0; dim < MAX_DIMS; ++dim) {
74
+ if (dim == dims) {
75
+ break;
76
+ }
77
+ auto divmod = sizes_[dim].divmod(linear_idx);
78
+ linear_idx = divmod.div;
79
+
80
+ #pragma unroll
81
+ for (int arg = 0; arg < NARGS; arg++) {
82
+ offsets[arg] += divmod.mod * strides_[dim][arg];
83
+ }
84
+
85
+ }
86
+ return offsets;
87
+ }
88
+
89
+ int dims;
90
+ at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
91
+ stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
92
+ };
93
+
94
+ template <int NARGS, typename index_t = uint32_t>
95
+ struct TrivialOffsetCalculator {
96
+ // The offset for each argument. Wrapper around fixed-size array.
97
+ // The offsets are in # of elements, not in bytes.
98
+ // On CUDA, zero sized array is not allowed, so when we are handling nullary
99
+ // operators, we need to create a size 1 offset to avoid compiler failure.
100
+ // This size 1 offset is just a placeholder, and we will not use it.
101
+ using offset_type = std::array<index_t, std::max<int>(NARGS, 1)>;
102
+
103
+ C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
104
+ offset_type offsets;
105
+ #pragma unroll
106
+ for (int arg = 0; arg < NARGS; arg++) {
107
+ offsets[arg] = linear_idx;
108
+ }
109
+ return offsets;
110
+ }
111
+ };
112
+
113
+ // Make an OffsetCalculator with byte offsets
114
+ template<int N, bool signed_strides = false>
115
+ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
116
+ TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
117
+ std::array<const int64_t*, N> strides;
118
+ for (int i = 0; i < N; i++) {
119
+ strides[i] = iter.strides(i).data();
120
+ }
121
+ return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
122
+ }
123
+
124
+ // Make an OffsetCalculator with element offsets
125
+ template<int N, bool signed_strides = false>
126
+ static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
127
+ const at::TensorIteratorBase& iter) {
128
+ TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
129
+ std::array<const int64_t*, N> strides;
130
+ std::array<int64_t, N> element_sizes;
131
+ for (int i = 0; i < N; i++) {
132
+ strides[i] = iter.strides(i).data();
133
+ element_sizes[i] = iter.element_size(i);
134
+ }
135
+ return OffsetCalculator<N, uint32_t, signed_strides>(
136
+ iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
137
+ }
138
+
139
+ #else
140
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
141
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
3
+ // Eager mode clients should not include this file directly, instead,
4
+ // they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
5
+
6
+ // Stores RNG state values. Passed as a kernel argument.
7
+ // See Note [CUDA Graph-safe RNG states].
8
+ //
9
+ // The raw definition lives in its own file so jit codegen can easily copy it.
10
+ namespace at {
11
+
12
+ struct PhiloxCudaState {
13
+ PhiloxCudaState() = default;
14
+ // Called if graph capture is not underway
15
+ PhiloxCudaState(uint64_t seed,
16
+ uint64_t offset) {
17
+ seed_.val = seed;
18
+ offset_.val = offset;
19
+ }
20
+ // Called if graph capture is underway
21
+ PhiloxCudaState(int64_t* seed,
22
+ int64_t* offset_extragraph,
23
+ uint64_t offset_intragraph) {
24
+ seed_.ptr = seed;
25
+ offset_.ptr = offset_extragraph;
26
+ offset_intragraph_ = offset_intragraph;
27
+ captured_ = true;
28
+ }
29
+
30
+ // Public members, directly accessible by at::cuda::philox::unpack.
31
+ // If we made them private with getters/setters, the getters/setters
32
+ // would have to be __device__, and we can't declare __device__ in ATen.
33
+ union Payload {
34
+ uint64_t val;
35
+ int64_t* ptr;
36
+ };
37
+
38
+ Payload seed_{};
39
+ Payload offset_{};
40
+ uint64_t offset_intragraph_ = 0;
41
+ bool captured_ = false;
42
+ };
43
+
44
+ } // namespace at
45
+
46
+ #else
47
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
48
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/CollapseDims.h>
5
+
6
+ namespace at::cuda::detail {
7
+
8
+ #define MAX_TENSORINFO_DIMS 25
9
+
10
+ // CUDA kernel argument that defines tensor layout
11
+ template <typename T, typename IndexType>
12
+ struct TensorInfo {
13
+ TensorInfo();
14
+ TensorInfo(T* p,
15
+ int dim,
16
+ IndexType sz[MAX_TENSORINFO_DIMS],
17
+ IndexType st[MAX_TENSORINFO_DIMS]);
18
+
19
+ // Set the size of the given dimension to 1, as if it were a
20
+ // reduction dim (allows you to calculate offsets of the reduction
21
+ // slice)
22
+ void reduceDim(int dim);
23
+
24
+ // See note on [collapse dims].
25
+ int collapseDims(const int excludeDim = -1);
26
+
27
+ // Contiguous tensors of more than one dimension are collapsed down
28
+ // to one tensor
29
+ __host__ __device__ inline bool isContiguous() const {
30
+ return (dims == 1 && strides[0] == 1);
31
+ }
32
+
33
+ T* data;
34
+ IndexType sizes[MAX_TENSORINFO_DIMS];
35
+ IndexType strides[MAX_TENSORINFO_DIMS];
36
+ int dims;
37
+ };
38
+
39
+ template <typename T, typename IndexType>
40
+ TensorInfo<T, IndexType>::TensorInfo() {
41
+ data = nullptr;
42
+ dims = 0;
43
+ }
44
+
45
+ template <typename T, typename IndexType>
46
+ TensorInfo<T, IndexType>::TensorInfo(T* p,
47
+ int dim,
48
+ IndexType sz[MAX_TENSORINFO_DIMS],
49
+ IndexType st[MAX_TENSORINFO_DIMS]) {
50
+ data = p;
51
+ dims = dim;
52
+ TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");
53
+
54
+ for (int i = 0; i < dim; ++i) {
55
+ sizes[i] = sz[i];
56
+ strides[i] = st[i];
57
+ }
58
+ }
59
+
60
+ template <typename T, typename IndexType>
61
+ void
62
+ TensorInfo<T, IndexType>::reduceDim(int dim) {
63
+ TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
64
+ sizes[dim] = 1;
65
+ }
66
+
67
+ template <typename T, typename IndexType>
68
+ int
69
+ TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
70
+ auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
71
+ dims = std::get<1>(result);
72
+ return std::get<0>(result);
73
+ }
74
+
75
+ // Translate a linear index for the apply to a T* offset;
76
+ // specialized on `Dims` to reduce nvcc compilation time
77
+ template <typename T, typename IndexType, int Dims>
78
+ struct IndexToOffset {
79
+ static __host__ __device__ IndexType get(
80
+ IndexType linearId,
81
+ const TensorInfo<T, IndexType>& info) {
82
+
83
+ IndexType offset = 0;
84
+
85
+ // Uses static dims
86
+ for (int i = Dims - 1; i > 0; --i) {
87
+ IndexType curDimIndex = linearId % info.sizes[i];
88
+ IndexType curDimOffset = curDimIndex * info.strides[i];
89
+ offset += curDimOffset;
90
+ linearId /= info.sizes[i];
91
+ }
92
+
93
+ return offset + linearId * info.strides[0];
94
+ }
95
+ };
96
+
97
+ // Uses dynamic (runtime) instead of static (compile time) dims
98
+ template <typename T, typename IndexType>
99
+ struct IndexToOffset<T, IndexType, -1> {
100
+ static inline __host__ __device__ IndexType get(
101
+ IndexType linearId,
102
+ const TensorInfo<T, IndexType>& info) {
103
+
104
+ IndexType offset = 0;
105
+
106
+ for (int i = info.dims - 1; i > 0; --i) {
107
+ IndexType curDimIndex = linearId % info.sizes[i];
108
+ IndexType curDimOffset = curDimIndex * info.strides[i];
109
+ offset += curDimOffset;
110
+ linearId /= info.sizes[i];
111
+ }
112
+
113
+ return offset + linearId * info.strides[0];
114
+ }
115
+ };
116
+
117
+ } // namespace at::cuda::detail
118
+
119
+ #else
120
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
121
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
3
+ // Eager mode clients should not include this file directly, instead,
4
+ // they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
5
+
6
+ namespace at::cuda::philox {
7
+
8
+ // In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
9
+ // that instance was created with graph capture underway or not.
10
+ // See Note [CUDA Graph-safe RNG states].
11
+ //
12
+ // We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
13
+ // Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
14
+ // Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
15
+ //
16
+ // The raw definition lives in its own file so jit codegen can easily copy it.
17
+ __host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
18
+ unpack(at::PhiloxCudaState arg) {
19
+ if (arg.captured_) {
20
+ // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
21
+ // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
22
+ // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
23
+ return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
24
+ } else {
25
+ return std::make_tuple(arg.seed_.val, arg.offset_.val);
26
+ }
27
+ }
28
+
29
+ // Adapted from TE
30
+ // extract seed and offset from PhiloxCudaState
31
+ __global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr);
32
+
33
+ void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream);
34
+
35
+ } // namespace at::cuda::philox
36
+
37
+ #else
38
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
39
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Original TunableOp is from onnxruntime.
3
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
4
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
5
+ // Copyright (c) Microsoft Corporation.
6
+ // Licensed under the MIT license.
7
+ //
8
+ // Adapting TunableOp into PyTorch
9
+ // Copyright (c) Advanced Micro Devices, Inc.
10
+ //
11
+ #pragma once
12
+
13
+ #include <string>
14
+ #include <c10/core/ScalarType.h>
15
+
16
+ #include <ATen/cuda/tunable/TunableOp.h>
17
+ #include <ATen/cuda/tunable/Tunable.h>
18
+ #include <ATen/cuda/CUDABlas.h>
19
+ #include <ATen/cuda/Exceptions.h>
20
+ #include <c10/util/StringUtil.h>
21
+
22
+ #ifndef AT_PER_OPERATOR_HEADERS
23
+ #include <ATen/Functions.h>
24
+ #include <ATen/NativeFunctions.h>
25
+ #else
26
+ #include <ATen/ops/allclose.h>
27
+ #include <ATen/ops/from_blob.h>
28
+ #endif
29
+ #include <ATen/OpMathType.h>
30
+ #include <fmt/printf.h>
31
+
32
+ namespace at::cuda::tunable {
33
+
34
+ using at::blas::ScalingType;
35
+
36
+ enum class BlasOp {
37
+ N = 0,
38
+ T = 1
39
+ };
40
+
41
+ inline char BlasOpToString(BlasOp op) {
42
+ switch (op) {
43
+ case BlasOp::N:
44
+ return 'N';
45
+ case BlasOp::T:
46
+ return 'T';
47
+ }
48
+ TORCH_CHECK(false, "unrecognized BlasOp");
49
+ return 'N';
50
+ }
51
+
52
+ template <typename T>
53
+ inline const char* BLASTypeName(T v) {
54
+ return "unknown";
55
+ }
56
+
57
+ template <>
58
+ inline const char* BLASTypeName(float v) {
59
+ return "f32_r";
60
+ }
61
+
62
+ template <>
63
+ inline const char* BLASTypeName(double v) {
64
+ return "f64_r";
65
+ }
66
+
67
+ template <>
68
+ inline const char* BLASTypeName(BFloat16 v) {
69
+ return "bf16_r";
70
+ }
71
+
72
+ template <>
73
+ inline const char* BLASTypeName(Half v) {
74
+ return "f16_r";
75
+ }
76
+
77
+ //https://github.com/ROCm/hipBLASLt/blob/develop/library/src/include/auxiliary.hpp#L175
78
+ template <>
79
+ inline const char* BLASTypeName(Float8_e4m3fn v) {
80
+ return "f8_r";
81
+ }
82
+
83
+ template <>
84
+ inline const char* BLASTypeName(Float8_e5m2 v) {
85
+ return "bf8_r";
86
+ }
87
+
88
+ template <>
89
+ inline const char* BLASTypeName(Float8_e4m3fnuz v) {
90
+ return "f8_fnuz_r";
91
+ }
92
+
93
+ template <>
94
+ inline const char* BLASTypeName(Float8_e5m2fnuz v) {
95
+ return "bf8_fnuz_r";
96
+ }
97
+
98
+ template <>
99
+ inline const char* BLASTypeName(c10::complex<double> v) {
100
+ return "f64_r";
101
+ }
102
+
103
+ template <>
104
+ inline const char* BLASTypeName(c10::complex<float> v) {
105
+ return "f32_r";
106
+ }
107
+
108
+ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
109
+ std::string BLASType;
110
+ switch (scalar_type) {
111
+ case c10::ScalarType::Float:{
112
+ BLASType = "f32_r";
113
+ break;
114
+ }
115
+ case c10::ScalarType::Double:{
116
+ BLASType = "f64_r";
117
+ break;
118
+ }
119
+ case c10::ScalarType::BFloat16:{
120
+ BLASType = "bf16_r";
121
+ break;
122
+ }
123
+ case c10::ScalarType::Half: {
124
+ BLASType = "f16_r";
125
+ break;
126
+ }
127
+ case c10::ScalarType::Float8_e4m3fn: {
128
+ BLASType = "f8_r";
129
+ break;
130
+ }
131
+ case c10::ScalarType::Float8_e5m2: {
132
+ BLASType = "bf8_r";
133
+ break;
134
+ }
135
+ case c10::ScalarType::Float8_e4m3fnuz: {
136
+ BLASType = "f8_fnuz_r";
137
+ break;
138
+ }
139
+ case c10::ScalarType::Float8_e5m2fnuz: {
140
+ BLASType = "bf8_fnuz_r";
141
+ break;
142
+ }
143
+ case c10::ScalarType::ComplexFloat:{
144
+ BLASType = "f32_c";
145
+ break;
146
+ }
147
+ case c10::ScalarType::ComplexDouble:{
148
+ BLASType = "f64_c";
149
+ break;
150
+ }
151
+ default:
152
+ BLASType = "unknown";
153
+ }
154
+ return BLASType;
155
+
156
+ }
157
+
158
+ // Similar to Compute Type in GemmRocblas.h
159
+ template <typename T>
160
+ inline std::string ComputeTypeFor() {
161
+ return "Unknown ComputeType";
162
+ }
163
+
164
+ // This is a union of the compute types for
165
+ // ROCBLAS and hipBLASLt.
166
+ template <>
167
+ inline std::string ComputeTypeFor<float>() {
168
+ if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) != at::Float32Precision::TF32) {
169
+ return "f32_r";
170
+ } else {
171
+ return "xf32_r";
172
+ }
173
+ }
174
+
175
+ template <>
176
+ inline std::string ComputeTypeFor<double>() {
177
+ return "f64_r";
178
+ }
179
+
180
+ template <>
181
+ inline std::string ComputeTypeFor<Half>() {
182
+ return "f32_r";
183
+ }
184
+
185
+ template <>
186
+ inline std::string ComputeTypeFor<BFloat16>() {
187
+ return "f32_r";
188
+ }
189
+
190
+ template <>
191
+ inline std::string ComputeTypeFor<c10::complex<float>>() {
192
+ return "f32_c";
193
+ }
194
+
195
+ template <>
196
+ inline std::string ComputeTypeFor<c10::complex<double>>() {
197
+ return "f64_c";
198
+ }
199
+
200
+ template <>
201
+ inline std::string ComputeTypeFor<Float8_e4m3fn>() {
202
+ return "f32_r";
203
+ }
204
+
205
+ template <>
206
+ inline std::string ComputeTypeFor<Float8_e5m2>() {
207
+ return "f32_r";
208
+ }
209
+
210
+ template <>
211
+ inline std::string ComputeTypeFor<Float8_e4m3fnuz>() {
212
+ return "f32_r";
213
+ }
214
+
215
+ template <>
216
+ inline std::string ComputeTypeFor<Float8_e5m2fnuz>() {
217
+ return "f32_r";
218
+ }
219
+
220
+ // Convert opmath_type<T> to string
221
+ template <typename T>
222
+ inline std::string to_string_opmath(const at::opmath_type<T>& value) {
223
+ if constexpr (std::is_same_v<at::opmath_type<T>, c10::complex<float>> ||
224
+ std::is_same_v<at::opmath_type<T>, c10::complex<double>>) {
225
+ return fmt::format("({:.4f}, {:.4f})", value.real(), value.imag());
226
+ } else {
227
+ return fmt::format("{:.4f}", value);
228
+ }
229
+ }
230
+
231
+ // convert activation epilogue to string
232
+ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivationEpilogue& value) {
233
+ switch (value) {
234
+ case at::cuda::blas::GEMMAndBiasActivationEpilogue::None:
235
+ return std::string("None");
236
+ break;
237
+ case at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU:
238
+ return std::string("RELU");
239
+ break;
240
+ case cuda::blas::GEMMAndBiasActivationEpilogue::GELU:
241
+ return std::string("GELU");
242
+ break;
243
+ default:
244
+ return std::string("unknown");
245
+ }
246
+ }
247
+
248
+ namespace detail {
249
+
250
+ static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
251
+
252
+ if (!config.enabled) {
253
+ return true; // skip when disabled
254
+ }
255
+
256
+ auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
257
+ at::Tensor ref = at::from_blob(c, {size}, options);
258
+ at::Tensor oth = at::from_blob(other_c, {size}, options);
259
+ at::Tensor ref_float = ref.to(at::kFloat);
260
+ at::Tensor oth_float = oth.to(at::kFloat);
261
+
262
+ const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
263
+ if (ok) {
264
+ TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
265
+ } else {
266
+ TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
267
+ }
268
+ return ok;
269
+ }
270
+
271
+ }
272
+
273
+ // Note on GetSizeA et al.
274
+ // Tensors can be dense or arbitrarily strided. We only need our copies to be large enough.
275
+ // Our copies must be at least as large as the m n k shapes dictate, but could be larger
276
+ // depending on the lda ldb ldc values. Similarly for the batched case.
277
+
278
+ template <typename T>
279
+ struct GemmParams : OpParams {
280
+ GemmParams() = default;
281
+
282
+ std::string BLASSignature() const override {
283
+ std::string alpha_str = to_string_opmath<T>(alpha);
284
+ std::string beta_str = to_string_opmath<T>(beta);
285
+ return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
286
+ "alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, bias_type: %s, compute_type: %s }",
287
+ m, n, k, lda, ldb, ldc, ldc, alpha_str, beta_str, transa, transb,
288
+ BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>(), ComputeTypeFor<T>());
289
+ }
290
+
291
+ std::string Signature() const override {
292
+ return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc);
293
+ }
294
+
295
+ size_t GetSizeA() const {
296
+ size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m);
297
+ size_t size_dense = m * k;
298
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
299
+ }
300
+
301
+ size_t GetSizeB() const {
302
+ size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k);
303
+ size_t size_dense = k * n;
304
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
305
+ }
306
+
307
+ size_t GetSizeC() const {
308
+ size_t size_stride = ldc * n;
309
+ size_t size_dense = m * n;
310
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
311
+ }
312
+
313
+ size_t GetSize(bool duplicate_inputs) const {
314
+ size_t size = GetSizeC();
315
+ if (duplicate_inputs) {
316
+ size += GetSizeA();
317
+ size += GetSizeB();
318
+ }
319
+ return size;
320
+ }
321
+
322
+ GemmParams* DeepCopy(bool duplicate_inputs) const {
323
+ GemmParams* copy = new GemmParams;
324
+ *copy = *this;
325
+ c10::DeviceIndex device = 0;
326
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
327
+ size_t c_size = GetSizeC();
328
+ copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
329
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
330
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
331
+ if (duplicate_inputs) {
332
+ size_t a_size = GetSizeA();
333
+ size_t b_size = GetSizeB();
334
+ copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
335
+ copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
336
+ copy->duplicate_inputs_ = true;
337
+ }
338
+ return copy;
339
+ }
340
+
341
+ // only call on object returned by DeepCopy
342
+ void Delete() {
343
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
344
+ if (duplicate_inputs_) {
345
+ // NOLINTNEXTLINE(*const-cast*)
346
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
347
+ // NOLINTNEXTLINE(*const-cast*)
348
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
349
+ }
350
+ }
351
+
352
+ TuningStatus NumericalCheck(GemmParams<T> *other) {
353
+ auto* ctx = getTuningContext();
354
+ auto cfg = ctx->GetNumericalCheckConfig();
355
+ auto c_dtype = c10::CppTypeToScalarType<T>::value;
356
+ return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
357
+ }
358
+
359
+ char transa{};
360
+ char transb{};
361
+ int64_t m{};
362
+ int64_t n{};
363
+ int64_t k{};
364
+ at::opmath_type<T> alpha;
365
+ const T* a{};
366
+ int64_t lda{};
367
+ const T* b{};
368
+ int64_t ldb{};
369
+ at::opmath_type<T> beta;
370
+ T* c{};
371
+ int64_t ldc{};
372
+ private:
373
+ bool duplicate_inputs_{false};
374
+ };
375
+
376
+ template <typename T>
377
+ struct GemmAndBiasParams : OpParams {
378
+ std::string BLASSignature() const override {
379
+ std::string alpha_str = to_string_opmath<T>(alpha);
380
+ std::string activation_str = to_string_epilogue(activation);
381
+ return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
382
+ "alpha: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, activation: %s, bias_type: %s, scale_type: %s, compute_type: %s }",
383
+ m, n, k, lda, ldb, ldc, ldc, alpha_str, transa, transb,
384
+ BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), activation_str, BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>(), ComputeTypeFor<T>());
385
+ }
386
+
387
+ std::string Signature() const override {
388
+ return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc);
389
+ }
390
+
391
+ size_t GetSizeA() const {
392
+ size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m);
393
+ size_t size_dense = m * k;
394
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
395
+ }
396
+
397
+ size_t GetSizeB() const {
398
+ size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k);
399
+ size_t size_dense = k * n;
400
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
401
+ }
402
+
403
+ size_t GetSizeC() const {
404
+ size_t size_stride = ldc * n;
405
+ size_t size_dense = m * n;
406
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
407
+ }
408
+
409
+ size_t GetSize(bool duplicate_inputs) const {
410
+ size_t size = GetSizeC();
411
+ if (duplicate_inputs) {
412
+ size += GetSizeA();
413
+ size += GetSizeB();
414
+ }
415
+ return size;
416
+ }
417
+
418
+ GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const {
419
+ GemmAndBiasParams* copy = new GemmAndBiasParams;
420
+ *copy = *this;
421
+ c10::DeviceIndex device = 0;
422
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
423
+ size_t c_size = GetSizeC();
424
+ copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
425
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
426
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
427
+ if (duplicate_inputs) {
428
+ size_t a_size = GetSizeA();
429
+ size_t b_size = GetSizeB();
430
+ copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
431
+ copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
432
+ copy->duplicate_inputs_ = true;
433
+ }
434
+ return copy;
435
+ }
436
+
437
+ // only call on object returned by DeepCopy
438
+ void Delete() {
439
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
440
+ if (duplicate_inputs_) {
441
+ // NOLINTNEXTLINE(*const-cast)
442
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
443
+ // NOLINTNEXTLINE(*const-cast)
444
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
445
+ }
446
+ }
447
+
448
+ TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
449
+ auto* ctx = getTuningContext();
450
+ auto cfg = ctx->GetNumericalCheckConfig();
451
+ auto c_dtype = c10::CppTypeToScalarType<T>::value;
452
+ return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
453
+ }
454
+
455
+ char transa{};
456
+ char transb{};
457
+ int64_t m{};
458
+ int64_t n{};
459
+ int64_t k{};
460
+ at::opmath_type<T> alpha{};
461
+ const T* a{};
462
+ int64_t lda{};
463
+ const T* b{};
464
+ int64_t ldb{};
465
+ T* c{};
466
+ int64_t ldc{};
467
+ const T* bias{};
468
+ at::cuda::blas::GEMMAndBiasActivationEpilogue activation{};
469
+ private:
470
+ bool duplicate_inputs_{false};
471
+ };
472
+
473
+ template <typename T, typename C_Dtype = T>
474
+ struct GemmStridedBatchedParams : OpParams {
475
+ std::string BLASSignature() const override {
476
+ std::string alpha_str = to_string_opmath<T>(alpha);
477
+ std::string beta_str = to_string_opmath<T>(beta);
478
+ return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: %ld, stride_b: %ld, stride_c: %ld, stride_d: %ld, "
479
+ "alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: %ld, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }",
480
+ m, n, k, lda, ldb, ldc, ldc, stride_a, stride_b, stride_c, stride_c, alpha_str, beta_str, transa, transb, batch,
481
+ BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<C_Dtype>(C_Dtype{}), BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>());
482
+ }
483
+
484
+ std::string Signature() const override {
485
+ return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, batch, lda, ldb, ldc);
486
+ }
487
+
488
+ size_t GetSizeA() const {
489
+ size_t size_stride = stride_a * batch;
490
+ size_t size_dense = m * k * batch;
491
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
492
+ }
493
+
494
+ size_t GetSizeB() const {
495
+ size_t size_stride = stride_b * batch;
496
+ size_t size_dense = k * n * batch;
497
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
498
+ }
499
+
500
+ size_t GetSizeC() const {
501
+ size_t size_stride = stride_c * batch;
502
+ size_t size_dense = m * n * batch;
503
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
504
+ }
505
+
506
+ size_t GetSize(bool duplicate_inputs) const {
507
+ size_t size = GetSizeC();
508
+ if (duplicate_inputs) {
509
+ size += GetSizeA();
510
+ size += GetSizeB();
511
+ }
512
+ return size;
513
+ }
514
+
515
+ GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
516
+ GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
517
+ *copy = *this;
518
+ c10::DeviceIndex device = 0;
519
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
520
+ size_t c_size = GetSizeC();
521
+ copy->c = static_cast<C_Dtype*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
522
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
523
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
524
+ if (duplicate_inputs) {
525
+ size_t a_size = GetSizeA();
526
+ size_t b_size = GetSizeB();
527
+ // NOLINTNEXTLINE(*const-cast*)
528
+ copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
529
+ // NOLINTNEXTLINE(*const-cast*)
530
+ copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
531
+ copy->duplicate_inputs_ = true;
532
+ }
533
+ return copy;
534
+ }
535
+
536
+ // only call on object returned by DeepCopy
537
+ void Delete() {
538
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
539
+ if (duplicate_inputs_) {
540
+ // NOLINTNEXTLINE(*const-cast*)
541
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
542
+ // NOLINTNEXTLINE(*const-cast*)
543
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
544
+ }
545
+ }
546
+
547
+ TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
548
+ auto* ctx = getTuningContext();
549
+ auto cfg = ctx->GetNumericalCheckConfig();
550
+ auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
551
+ return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
552
+ }
553
+
554
+ char transa{};
555
+ char transb{};
556
+ int64_t m{};
557
+ int64_t n{};
558
+ int64_t k{};
559
+ at::opmath_type<T> alpha{};
560
+ const T* a{};
561
+ int64_t lda{};
562
+ int64_t stride_a{};
563
+ const T* b{};
564
+ int64_t ldb{};
565
+ int64_t stride_b{};
566
+ at::opmath_type<T> beta;
567
+ C_Dtype* c{};
568
+ int64_t ldc{};
569
+ int64_t stride_c{};
570
+ int64_t batch{};
571
+ private:
572
+ bool duplicate_inputs_{false};
573
+ };
574
+
575
+ template <typename T>
576
+ struct ScaledGemmParams : OpParams {
577
+ ScaledGemmParams() = default;
578
+
579
+ std::string BLASSignature() const override {
580
+ // Excluding use_fast_accum and use_rowise booleans for now
581
+ if (bias_ptr == nullptr) {
582
+ return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
583
+ "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }",
584
+ m, n, k, lda, ldb, ldc, ldc, transa, transb,
585
+ ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype),
586
+ ComputeTypeFor<T>(), ComputeTypeFor<T>());
587
+ }
588
+ else {
589
+ return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
590
+ "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, bias_type: %s, scale_type: %s, compute_type: %s }",
591
+ m, n, k, lda, ldb, ldc, ldc, transa, transb,
592
+ ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(bias_dtype),
593
+ ComputeTypeFor<T>(), ComputeTypeFor<T>());
594
+ }
595
+ }
596
+
597
+ std::string Signature() const override {
598
+ // In Blas.cpp, code defaults to a bias_dtype of Half even when there is no bias vector.
599
+ // Search for this line::
600
+ // params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
601
+ //
602
+ // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector.
603
+ return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s",
604
+ transa, transb, m, n, k, lda, ldb, ldc,
605
+ a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise,
606
+ bias_ptr == nullptr ? "None" : at::toString(bias_dtype));
607
+ }
608
+
609
+ size_t GetSizeA() const {
610
+ size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m);
611
+ size_t size_dense = m * k;
612
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
613
+ }
614
+
615
+ size_t GetSizeB() const {
616
+ size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k);
617
+ size_t size_dense = k * n;
618
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
619
+ }
620
+
621
+ size_t GetSizeC() const {
622
+ size_t size_stride = ldc * n;
623
+ size_t size_dense = m * n;
624
+ return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
625
+ }
626
+
627
+ size_t GetSize(bool duplicate_inputs) const {
628
+ size_t size = GetSizeC();
629
+ if (duplicate_inputs) {
630
+ size += GetSizeA();
631
+ size += GetSizeB();
632
+ }
633
+ return size;
634
+ }
635
+
636
+ ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
637
+ ScaledGemmParams* copy = new ScaledGemmParams;
638
+ *copy = *this;
639
+ c10::DeviceIndex device = 0;
640
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
641
+ size_t c_size = GetSizeC();
642
+ copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
643
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
644
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
645
+ if (duplicate_inputs) {
646
+ size_t a_size = GetSizeA();
647
+ size_t b_size = GetSizeB();
648
+ copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
649
+ copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
650
+ copy->duplicate_inputs_ = true;
651
+ }
652
+ return copy;
653
+ }
654
+
655
+ // only call on object returned by DeepCopy
656
+ void Delete() {
657
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
658
+ if (duplicate_inputs_) {
659
+ // NOLINTNEXTLINE(*const-cast*)
660
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
661
+ // NOLINTNEXTLINE(*const-cast*)
662
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
663
+ }
664
+ }
665
+
666
+ TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
667
+ auto* ctx = getTuningContext();
668
+ auto cfg = ctx->GetNumericalCheckConfig();
669
+ return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
670
+ }
671
+
672
+ char transa{};
673
+ char transb{};
674
+ int64_t m{};
675
+ int64_t n{};
676
+ int64_t k{};
677
+ const void* a{};
678
+ const void* a_scale_ptr{};
679
+ int64_t lda{};
680
+ ScalarType a_dtype{};
681
+ ScalarType a_scale_dtype{};
682
+ ScalingType a_scaling_type{};
683
+ const void* b{};
684
+ const void* b_scale_ptr{};
685
+ int64_t ldb{};
686
+ ScalarType b_dtype{};
687
+ ScalarType b_scale_dtype{};
688
+ ScalingType b_scaling_type{};
689
+ const void* bias_ptr{};
690
+ ScalarType bias_dtype{};
691
+ void* c{};
692
+ const void* c_scale_ptr{};
693
+ int64_t ldc{};
694
+ ScalarType c_dtype{};
695
+ void* amax_ptr{};
696
+ bool use_fast_accum{};
697
+ private:
698
+ bool duplicate_inputs_{false};
699
+ };
700
+
701
+ } // namespace at::cuda::tunable
702
+
703
+ #else
704
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
705
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Microsoft Corporation. All rights reserved.
3
+ // Licensed under the MIT License.
4
+
5
+ #pragma once
6
+
7
+ #include <ATen/cuda/CUDAContext.h>
8
+ #include <ATen/cuda/CUDADataType.h>
9
+ #include <ATen/cuda/tunable/TunableOp.h>
10
+ #include <ATen/cuda/tunable/GemmCommon.h>
11
+ #include <c10/cuda/CUDACachingAllocator.h>
12
+ #include <c10/util/StringUtil.h>
13
+ #include <fmt/printf.h>
14
+
15
+ #include <hipblaslt/hipblaslt.h>
16
+ #include <hipblaslt/hipblaslt-ext.hpp>
17
+
18
+ #define TORCH_HIPBLASLT_CHECK(EXPR) \
19
+ do { \
20
+ hipblasStatus_t __err = EXPR; \
21
+ TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \
22
+ "hipblaslt error: ", \
23
+ hipblasStatusToString(__err), \
24
+ " when calling `" #EXPR "`"); \
25
+ } while (0)
26
+
27
+ namespace at::cuda::tunable {
28
+
29
+ template <typename T>
30
+ constexpr hipDataType HipDataTypeFor();
31
+
32
+ template <>
33
+ constexpr hipDataType HipDataTypeFor<float>() {
34
+ return HIP_R_32F;
35
+ }
36
+
37
+ template <>
38
+ constexpr hipDataType HipDataTypeFor<Half>() {
39
+ return HIP_R_16F;
40
+ }
41
+
42
+ template <>
43
+ constexpr hipDataType HipDataTypeFor<BFloat16>() {
44
+ return HIP_R_16BF;
45
+ }
46
+
47
+ template <>
48
+ constexpr hipDataType HipDataTypeFor<double>() {
49
+ return HIP_R_64F;
50
+ }
51
+
52
+ template <>
53
+ constexpr hipDataType HipDataTypeFor<c10::Float8_e4m3fnuz>() {
54
+ return HIP_R_8F_E4M3_FNUZ;
55
+ }
56
+
57
+ template <>
58
+ constexpr hipDataType HipDataTypeFor<c10::Float8_e5m2fnuz>() {
59
+ return HIP_R_8F_E5M2_FNUZ;
60
+ }
61
+
62
+ // This code is instantiated regardless of ROCm version.
63
+ // Prior to ROCm 6.3, we hard-code the known enum values.
64
+ template <>
65
+ constexpr hipDataType HipDataTypeFor<c10::Float8_e4m3fn>() {
66
+ #if ROCM_VERSION >= 60300
67
+ return HIP_R_8F_E4M3;
68
+ #else
69
+ return static_cast<hipDataType>(28);
70
+ #endif
71
+ }
72
+
73
+ template <>
74
+ constexpr hipDataType HipDataTypeFor<c10::Float8_e5m2>() {
75
+ #if ROCM_VERSION >= 60300
76
+ return HIP_R_8F_E5M2;
77
+ #else
78
+ return static_cast<hipDataType>(29);
79
+ #endif
80
+ }
81
+
82
+ // This type is not intended for matrix types but rather a scale factor.
83
+ // Return a dummy value to satisfy linker.
84
+ template <>
85
+ constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
86
+ return static_cast<hipDataType>(500);
87
+ }
88
+
89
+ template <>
90
+ constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
91
+ #if ROCM_VERSION >= 70000
92
+ return HIP_R_4F_E2M1;
93
+ #else
94
+ return static_cast<hipDataType>(33);
95
+ #endif
96
+ }
97
+
98
+ template <typename T>
99
+ int GetBatchFromParams(const GemmParams<T>* params) {
100
+ return 1;
101
+ }
102
+
103
+ template <typename T>
104
+ int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
105
+ return 1;
106
+ }
107
+
108
+ template <typename T>
109
+ int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
110
+ return params->batch;
111
+ }
112
+
113
+ template <typename T>
114
+ int GetBatchFromParams(const ScaledGemmParams<T>* params) {
115
+ return 1;
116
+ }
117
+
118
+ template <typename T>
119
+ int GetStrideAFromParams(const GemmParams<T>* params) {
120
+ return 1;
121
+ }
122
+
123
+ template <typename T>
124
+ int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
125
+ return 1;
126
+ }
127
+
128
+ template <typename T>
129
+ int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
130
+ return params->stride_a;
131
+ }
132
+
133
+ template <typename T>
134
+ int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
135
+ return 1;
136
+ }
137
+
138
+ template <typename T>
139
+ int GetStrideBFromParams(const GemmParams<T>* params) {
140
+ return 1;
141
+ }
142
+
143
+ template <typename T>
144
+ int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
145
+ return 1;
146
+ }
147
+
148
+ template <typename T>
149
+ int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
150
+ return params->stride_b;
151
+ }
152
+
153
+ template <typename T>
154
+ int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
155
+ return 1;
156
+ }
157
+
158
+ template <typename T>
159
+ int GetStrideCFromParams(const GemmParams<T>* params) {
160
+ return 1;
161
+ }
162
+
163
+ template <typename T>
164
+ int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
165
+ return 1;
166
+ }
167
+
168
+ template <typename T>
169
+ int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
170
+ return params->stride_c;
171
+ }
172
+
173
+ template <typename T>
174
+ int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
175
+ return 1;
176
+ }
177
+
178
+ template <typename T>
179
+ float GetAlphaFromParams(const GemmParams<T>* params) {
180
+ return params->alpha;
181
+ }
182
+
183
+ template <typename T>
184
+ float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
185
+ return params->alpha;
186
+ }
187
+
188
+ template <typename T>
189
+ float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
190
+ return params->alpha;
191
+ }
192
+
193
+ template <typename T>
194
+ float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
195
+ return 1.0;
196
+ }
197
+
198
+ template <typename T>
199
+ float GetBetaFromParams(const GemmParams<T>* params) {
200
+ return params->beta;
201
+ }
202
+
203
+ template <typename T>
204
+ float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
205
+ return 0.0;
206
+ }
207
+
208
+ template <typename T>
209
+ float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
210
+ return params->beta;
211
+ }
212
+
213
+ template <typename T>
214
+ float GetBetaFromParams(const ScaledGemmParams<T>* params) {
215
+ return 0.0;
216
+ }
217
+
218
+ template <typename T>
219
+ ScalingType GetAScalingTypeFromParams(const GemmParams<T>* params) {
220
+ return ScalingType::TensorWise;
221
+ }
222
+
223
+ template <typename T>
224
+ ScalingType GetBScalingTypeFromParams(const GemmParams<T>* params) {
225
+ return ScalingType::TensorWise;
226
+ }
227
+
228
+ template <typename T>
229
+ ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
230
+ return ScalingType::TensorWise;
231
+ }
232
+
233
+ template <typename T>
234
+ ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
235
+ return ScalingType::TensorWise;
236
+ }
237
+
238
+ template <typename T>
239
+ ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
240
+ return ScalingType::TensorWise;
241
+ }
242
+
243
+ template <typename T>
244
+ ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
245
+ return ScalingType::TensorWise;
246
+ }
247
+
248
+ template <typename T>
249
+ ScalingType GetAScalingTypeFromParams(const ScaledGemmParams<T>* params) {
250
+ return params->a_scaling_type;
251
+ }
252
+
253
+ template <typename T>
254
+ ScalingType GetBScalingTypeFromParams(const ScaledGemmParams<T>* params) {
255
+ return params->b_scaling_type;
256
+ }
257
+
258
+ template <typename T>
259
+ const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
260
+ return nullptr;
261
+ }
262
+
263
+ template <typename T>
264
+ const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
265
+ return nullptr;
266
+ }
267
+
268
+ template <typename T>
269
+ const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
270
+ return nullptr;
271
+ }
272
+
273
+ template <typename T>
274
+ const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
275
+ return params->a_scale_ptr;
276
+ }
277
+
278
+ template <typename T>
279
+ const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
280
+ return nullptr;
281
+ }
282
+
283
+ template <typename T>
284
+ const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
285
+ return nullptr;
286
+ }
287
+
288
+ template <typename T>
289
+ const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
290
+ return nullptr;
291
+ }
292
+
293
+ template <typename T>
294
+ const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
295
+ return params->b_scale_ptr;
296
+ }
297
+
298
+ template <typename T>
299
+ const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
300
+ return nullptr;
301
+ }
302
+
303
+ template <typename T>
304
+ const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
305
+ return nullptr;
306
+ }
307
+
308
+ template <typename T>
309
+ const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
310
+ return nullptr;
311
+ }
312
+
313
+ template <typename T>
314
+ const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
315
+ return params->c_scale_ptr;
316
+ }
317
+
318
+ template <typename T>
319
+ const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
320
+ return nullptr;
321
+ }
322
+
323
+ template <typename T>
324
+ const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
325
+ return params->bias;
326
+ }
327
+
328
+ template <typename T>
329
+ const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
330
+ return nullptr;
331
+ }
332
+
333
+ template <typename T>
334
+ const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
335
+ return params->bias_ptr;
336
+ }
337
+
338
+ template <typename T>
339
+ hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
340
+ return HIP_R_32F;
341
+ }
342
+
343
+ template <typename T>
344
+ hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
345
+ return HipDataTypeFor<T>();
346
+ }
347
+
348
+ template <typename T>
349
+ hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
350
+ return HIP_R_32F;
351
+ }
352
+
353
+ template <typename T>
354
+ hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
355
+ return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
356
+ }
357
+
358
+ template <typename T>
359
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
360
+ return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
361
+ }
362
+
363
+ template <typename T>
364
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
365
+ return params->activation;
366
+ }
367
+
368
+ template <typename T>
369
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
370
+ return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
371
+ }
372
+
373
+ template <typename T>
374
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
375
+ return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
376
+ }
377
+
378
+ static hipblasOperation_t _hipblasOpFromChar(char op) {
379
+ switch (op) {
380
+ case 'n':
381
+ case 'N':
382
+ return HIPBLAS_OP_N;
383
+ case 't':
384
+ case 'T':
385
+ return HIPBLAS_OP_T;
386
+ case 'c':
387
+ case 'C':
388
+ return HIPBLAS_OP_C;
389
+ }
390
+ TORCH_CHECK(false,
391
+ "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
392
+ }
393
+
394
+ static char _charFromhipblasOp(hipblasOperation_t op) {
395
+ switch (op) {
396
+ case HIPBLAS_OP_N:
397
+ return 'N';
398
+ case HIPBLAS_OP_T:
399
+ return 'T';
400
+ case HIPBLAS_OP_C:
401
+ return 'C';
402
+ }
403
+ TORCH_CHECK(false,
404
+ "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
405
+ }
406
+
407
+ static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
408
+ if (layout == BlasOp::N) {
409
+ return HIPBLAS_OP_N;
410
+ }
411
+ return HIPBLAS_OP_T;
412
+ }
413
+
414
+ template <typename T, cublasStatus_t (*destructor)(T*)>
415
+ struct HipBlasLtDeleter {
416
+ void operator()(T* x) {
417
+ if (x != nullptr) {
418
+ TORCH_CUDABLAS_CHECK(destructor(x));
419
+ }
420
+ }
421
+ };
422
+
423
+ template <typename T, hipblasStatus_t (*destructor)(T*)>
424
+ class HipBlasLtDescriptor {
425
+ public:
426
+ T* descriptor() const {
427
+ return descriptor_.get();
428
+ }
429
+ T* descriptor() {
430
+ return descriptor_.get();
431
+ }
432
+
433
+ protected:
434
+ std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
435
+ };
436
+
437
+ class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
438
+ hipblasLtMatmulDescOpaque_t,
439
+ &hipblasLtMatmulDescDestroy> {
440
+ public:
441
+ HipBlasLtMatmulDescriptor(
442
+ hipblasComputeType_t compute_type,
443
+ hipDataType scale_type) {
444
+ hipblasLtMatmulDesc_t raw_descriptor = nullptr;
445
+ TORCH_HIPBLASLT_CHECK(
446
+ hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
447
+ descriptor_.reset(raw_descriptor);
448
+ }
449
+ template <typename T>
450
+ inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
451
+ TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
452
+ }
453
+ };
454
+
455
+ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
456
+ class HipblasltGemmOp : public Callable<ParamsT> {
457
+ public:
458
+ HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
459
+
460
+ TuningStatus Call(const ParamsT* params) override {
461
+ hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
462
+ hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
463
+ auto a_datatype = HipDataTypeFor<AT>();
464
+ auto b_datatype = HipDataTypeFor<BT>();
465
+ auto in_out_datatype = HipDataTypeFor<CT>();
466
+ auto opa = _hipblasOpFromChar(params->transa);
467
+ auto opb = _hipblasOpFromChar(params->transb);
468
+
469
+ TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
470
+
471
+ float alpha = GetAlphaFromParams<CT>(params);
472
+ float beta = GetBetaFromParams<CT>(params);
473
+
474
+ hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
475
+ if (opa == HIPBLAS_OP_N) {
476
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
477
+ }
478
+ else {
479
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
480
+ }
481
+ if (opb == HIPBLAS_OP_N) {
482
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
483
+ }
484
+ else {
485
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
486
+ }
487
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
488
+
489
+ // specific to batched gemmm
490
+ int batch = GetBatchFromParams<CT>(params);
491
+ if (batch > 1) {
492
+ int64_t stride_a = GetStrideAFromParams<CT>(params);
493
+ int64_t stride_b = GetStrideBFromParams<CT>(params);
494
+ int64_t stride_c = GetStrideCFromParams<CT>(params);
495
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
496
+ mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
497
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
498
+ mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
499
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
500
+ mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
501
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
502
+ mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
503
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
504
+ mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
505
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
506
+ mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
507
+ }
508
+
509
+ hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
510
+ if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
511
+ computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
512
+ }
513
+ HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
514
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
515
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
516
+
517
+ // specific to scaled gemm
518
+ const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
519
+ const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
520
+ const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
521
+ if (mat1_scale_ptr && mat2_scale_ptr) {
522
+ hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
523
+ hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
524
+ if (GetAScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
525
+ #if defined(HIPBLASLT_OUTER_VEC)
526
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
527
+ #elif defined(HIPBLASLT_VEC_EXT)
528
+ a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
529
+ #endif
530
+ }
531
+ if (GetBScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
532
+ #if defined(HIPBLASLT_OUTER_VEC)
533
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
534
+ #elif defined(HIPBLASLT_VEC_EXT)
535
+ b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
536
+ #endif
537
+ }
538
+ matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr);
539
+ matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr);
540
+ }
541
+ if (result_scale_ptr) {
542
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
543
+ }
544
+
545
+ const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
546
+ auto bias_datatype = GetBiasTypeFromParams<CT>(params);
547
+ if (bias_ptr) {
548
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
549
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
550
+ auto activation = GetActivationFromParams<CT>(params);
551
+ if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
552
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
553
+ }
554
+ else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
555
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
556
+ }
557
+ else {
558
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
559
+ }
560
+ }
561
+
562
+ size_t workspace_size = at::cuda::getCUDABlasLtWorkspaceSize();
563
+
564
+ auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
565
+
566
+ size_t ret_workspace_size = 0;
567
+ auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
568
+ matmul.descriptor(),
569
+ &alpha,
570
+ mat_a,
571
+ mat_b,
572
+ &beta,
573
+ mat_c,
574
+ mat_c,
575
+ algo_,
576
+ ret_workspace_size);
577
+
578
+ if (status == HIPBLAS_STATUS_SUCCESS) {
579
+ if (ret_workspace_size >= workspace_size) {
580
+ return FAIL;
581
+ }
582
+ }
583
+ else {
584
+ return FAIL;
585
+ }
586
+
587
+ void* workspace_buffer = at::cuda::getCUDABlasLtWorkspace();
588
+
589
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
590
+ matmul.descriptor(),
591
+ &alpha,
592
+ params->a,
593
+ mat_a,
594
+ params->b,
595
+ mat_b,
596
+ &beta,
597
+ params->c,
598
+ mat_c,
599
+ params->c,
600
+ mat_c,
601
+ &algo_,
602
+ workspace_buffer,
603
+ workspace_size,
604
+ at::cuda::getCurrentCUDAStream()));
605
+
606
+ //TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
607
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
608
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
609
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
610
+ return OK;
611
+ }
612
+
613
+ private:
614
+ hipblasLtMatmulAlgo_t algo_;
615
+ };
616
+
617
+ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
618
+ auto GetHipBlasLtTypeStringAndOps() {
619
+ hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
620
+ hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
621
+ auto a_datatype = HipDataTypeFor<AT>();
622
+ auto b_datatype = HipDataTypeFor<BT>();
623
+ auto in_out_datatype = HipDataTypeFor<CT>();
624
+ std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
625
+ #if ROCM_VERSION == 60400
626
+ // hipblaslt TT fp32 regression on ROCm 6.4, cannot use
627
+ if ((a_datatype == HIP_R_32F || b_datatype == HIP_R_32F || in_out_datatype == HIP_R_32F)
628
+ && (transa_outer == HIPBLAS_OP_T && transb_outer == HIPBLAS_OP_T)) {
629
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ignore;
630
+ return ignore;
631
+ }
632
+ #endif
633
+
634
+ hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
635
+ if (at::globalContext().allowTF32CuBLAS()) {
636
+ computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
637
+ }
638
+
639
+ hipblasLtHandle_t handle;
640
+ TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
641
+ TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
642
+ hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
643
+ transa_outer,
644
+ transb_outer,
645
+ a_datatype,
646
+ b_datatype,
647
+ in_out_datatype,
648
+ in_out_datatype,
649
+ computeType,
650
+ heuristic_result));
651
+ TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
652
+
653
+ int returned_algo_count = heuristic_result.size();
654
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
655
+ for (int i = 0; i < returned_algo_count; i++) {
656
+ auto algo = heuristic_result[i].algo;
657
+ int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
658
+ auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
659
+ std::string type_string = fmt::sprintf("Gemm_Hipblaslt_%d", algo_index);
660
+ ret.emplace_back(type_string, std::move(callable));
661
+ }
662
+
663
+ return ret;
664
+ }
665
+
666
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
667
+ auto GetHipBlasLtGemmTypeStringAndOps() {
668
+ return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
669
+ }
670
+
671
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
672
+ auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
673
+ return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
674
+ }
675
+
676
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
677
+ auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
678
+ return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
679
+ }
680
+
681
+ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
682
+ auto GetHipBlasLtScaledGemmTypeStringAndOps() {
683
+ return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
684
+ }
685
+
686
+ #undef TORCH_HIPBLASLT_CHECK
687
+
688
+ } // namespace at::cuda::tunable
689
+
690
+ #else
691
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
692
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Microsoft Corporation. All rights reserved.
3
+ // Licensed under the MIT License.
4
+
5
+ #pragma once
6
+
7
+ #include <ATen/cuda/CUDAContext.h>
8
+ #include <ATen/cuda/tunable/TunableOp.h>
9
+ #include <ATen/cuda/tunable/GemmCommon.h>
10
+ #include <c10/util/StringUtil.h>
11
+ #include <fmt/printf.h>
12
+
13
+ #define ROCBLAS_BETA_FEATURES_API
14
+ #include <rocblas/rocblas.h>
15
+
16
+ #define TORCH_ROCBLAS_CHECK(EXPR) \
17
+ do { \
18
+ rocblas_status __err = EXPR; \
19
+ TORCH_CHECK(__err == rocblas_status_success, \
20
+ "rocblas error: ", \
21
+ rocblas_status_to_string(__err), \
22
+ " when calling `" #EXPR "`"); \
23
+ } while (0)
24
+
25
+ namespace at::cuda::tunable {
26
+
27
+ template <typename T>
28
+ constexpr rocblas_datatype RocBlasDataTypeFor();
29
+
30
+ template <>
31
+ constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
32
+ return rocblas_datatype_f32_r;
33
+ }
34
+
35
+ template <>
36
+ constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
37
+ return rocblas_datatype_f64_r;
38
+ }
39
+
40
+ template <>
41
+ constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
42
+ return rocblas_datatype_f16_r;
43
+ }
44
+
45
+ template <>
46
+ constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
47
+ return rocblas_datatype_bf16_r;
48
+ }
49
+
50
+ template <>
51
+ constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
52
+ return rocblas_datatype_f32_c;
53
+ }
54
+
55
+ template <>
56
+ constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
57
+ return rocblas_datatype_f64_c;
58
+ }
59
+
60
+ template <typename T>
61
+ constexpr rocblas_datatype RocBlasComputeTypeFor();
62
+
63
+ template <>
64
+ constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
65
+ return rocblas_datatype_f32_r;
66
+ }
67
+
68
+ template <>
69
+ constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
70
+ return rocblas_datatype_f64_r;
71
+ }
72
+
73
+ template <>
74
+ constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
75
+ // Note that we're returning the _compute_ type for a given datatype.
76
+ // As of 12/2022, using compute type FP16 for 16-bit floats was much
77
+ // slower than using compute type FP32. So we use FP32 compute even for
78
+ // FP16 datatypes. This is how GEMM is implemented even in the function
79
+ // rocblasGemmHelper (see fpgeneric.h)
80
+ return rocblas_datatype_f32_r;
81
+ }
82
+
83
+ template <>
84
+ constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
85
+ // Note that we're returning the _compute_ type for a given datatype.
86
+ // As of 12/2022, using compute type FP16 for 16-bit floats was much
87
+ // slower than using compute type FP32. So we use FP32 compute even for
88
+ // BF16 datatypes. This is how GEMM is implemented even in the function
89
+ // rocblasGemmHelper (see fpgeneric.h)
90
+ return rocblas_datatype_f32_r;
91
+ }
92
+
93
+ template <>
94
+ constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
95
+ return rocblas_datatype_f32_c;
96
+ }
97
+
98
+ template <>
99
+ constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
100
+ return rocblas_datatype_f64_c;
101
+ }
102
+
103
+ template <typename T>
104
+ auto DoCastForHalfOrBfloat16(const T fp) {
105
+ return fp;
106
+ }
107
+
108
+ template <>
109
+ inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
110
+ // alpha and beta should be the same as compute_type, in Half case it is float.
111
+ float h = fp;
112
+ return h;
113
+ }
114
+
115
+ template <>
116
+ inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
117
+ // alpha and beta should be the same as compute_type, in bfloat16 case it is float.
118
+ float h = fp;
119
+ return h;
120
+ }
121
+
122
+ static rocblas_operation _rocblasOpFromChar(char op) {
123
+ switch (op) {
124
+ case 'n':
125
+ case 'N':
126
+ return rocblas_operation_none;
127
+ case 't':
128
+ case 'T':
129
+ return rocblas_operation_transpose;
130
+ case 'c':
131
+ case 'C':
132
+ return rocblas_operation_conjugate_transpose;
133
+ }
134
+ TORCH_CHECK(false,
135
+ "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
136
+ }
137
+
138
+ template <typename T>
139
+ class RocblasGemmOp : public Callable<GemmParams<T>> {
140
+ public:
141
+ RocblasGemmOp(int solution) : solution_{solution} {}
142
+
143
+ TuningStatus Call(const GemmParams<T>* params) override {
144
+ auto input_output_type = RocBlasDataTypeFor<T>();
145
+ if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
146
+ return FAIL; // no support for TF32 in rocBLAS
147
+ auto compute_type = RocBlasComputeTypeFor<T>();
148
+ auto h_a = DoCastForHalfOrBfloat16(params->alpha);
149
+ auto h_b = DoCastForHalfOrBfloat16(params->beta);
150
+ auto status = rocblas_gemm_ex(
151
+ (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
152
+ _rocblasOpFromChar(params->transa),
153
+ _rocblasOpFromChar(params->transb),
154
+ params->m, params->n, params->k,
155
+ &h_a,
156
+ params->a, input_output_type, params->lda,
157
+ params->b, input_output_type, params->ldb,
158
+ &h_b,
159
+ params->c, input_output_type, params->ldc,
160
+ params->c, input_output_type, params->ldc,
161
+ compute_type,
162
+ rocblas_gemm_algo_solution_index,
163
+ solution_,
164
+ rocblas_gemm_flags_none);
165
+ if (status != rocblas_status_success) {
166
+ return FAIL;
167
+ }
168
+ return OK;
169
+ }
170
+
171
+ private:
172
+ int solution_;
173
+ };
174
+
175
+ template <typename T>
176
+ auto GetRocBlasGemmTypeStringAndOps() {
177
+ rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
178
+ int solution_size;
179
+ auto input_output_type = RocBlasDataTypeFor<T>();
180
+ auto compute_type = RocBlasComputeTypeFor<T>();
181
+ // Get the number of available solutions
182
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
183
+ input_output_type,
184
+ input_output_type,
185
+ compute_type,
186
+ rocblas_gemm_flags_none,
187
+ nullptr,
188
+ &solution_size));
189
+ std::vector<int> solutions(solution_size);
190
+ // Get the list of available solutions
191
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
192
+ input_output_type,
193
+ input_output_type,
194
+ compute_type,
195
+ rocblas_gemm_flags_none,
196
+ solutions.data(),
197
+ &solution_size));
198
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
199
+ for (size_t i = 0; i < solutions.size(); ++i) {
200
+ auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
201
+ ret.emplace_back(std::make_pair(fmt::sprintf("Gemm_Rocblas_%d", solutions[i]), std::move(callable)));
202
+ }
203
+ return ret;
204
+ }
205
+
206
+ template <typename T>
207
+ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
208
+ public:
209
+ RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
210
+
211
+ TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
212
+ auto input_output_type = RocBlasDataTypeFor<T>();
213
+ if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
214
+ return FAIL; // no support for TF32 in rocBLAS
215
+ auto compute_type = RocBlasComputeTypeFor<T>();
216
+ auto h_a = DoCastForHalfOrBfloat16(params->alpha);
217
+ auto h_b = DoCastForHalfOrBfloat16(params->beta);
218
+ auto status = rocblas_gemm_strided_batched_ex(
219
+ (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
220
+ _rocblasOpFromChar(params->transa),
221
+ _rocblasOpFromChar(params->transb),
222
+ params->m, params->n, params->k,
223
+ &h_a,
224
+ params->a, input_output_type, params->lda, params->stride_a,
225
+ params->b, input_output_type, params->ldb, params->stride_b,
226
+ &h_b,
227
+ params->c, input_output_type, params->ldc, params->stride_c,
228
+ params->c, input_output_type, params->ldc, params->stride_c,
229
+ params->batch,
230
+ compute_type,
231
+ rocblas_gemm_algo_solution_index,
232
+ solution_,
233
+ rocblas_gemm_flags_none);
234
+ if (status != rocblas_status_success) {
235
+ return FAIL;
236
+ }
237
+ return OK;
238
+ }
239
+
240
+ private:
241
+ int solution_;
242
+ };
243
+
244
+ template <typename T>
245
+ auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
246
+ rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
247
+ int solution_size;
248
+ auto input_output_type = RocBlasDataTypeFor<T>();
249
+ auto compute_type = RocBlasComputeTypeFor<T>();
250
+ // Get the number of available solutions
251
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
252
+ input_output_type,
253
+ input_output_type,
254
+ compute_type,
255
+ rocblas_gemm_flags_none,
256
+ nullptr,
257
+ &solution_size));
258
+ std::vector<int> solutions(solution_size);
259
+ // Get the list of available solutions
260
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
261
+ input_output_type,
262
+ input_output_type,
263
+ compute_type,
264
+ rocblas_gemm_flags_none,
265
+ solutions.data(),
266
+ &solution_size));
267
+ // Sort the solutions in ascending order to make the solution vector deterministic across runs
268
+ std::sort(solutions.begin(), solutions.end());
269
+
270
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
271
+ for (size_t i = 0; i < solutions.size(); ++i) {
272
+ auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
273
+ ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
274
+ }
275
+ return ret;
276
+ }
277
+
278
+ } // namespace at::cuda::tunable
279
+
280
+ #else
281
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
282
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Original TunableOp is from onnxruntime.
3
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
4
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
5
+ // Copyright (c) Microsoft Corporation.
6
+ // Licensed under the MIT license.
7
+ //
8
+ // Adapting TunableOp into PyTorch
9
+ // Copyright (c) Advanced Micro Devices, Inc.
10
+ //
11
+ #pragma once
12
+
13
+ #include <cuda_runtime.h>
14
+
15
+ #include <ATen/cuda/tunable/Tunable.h>
16
+
17
+ namespace at::cuda::tunable {
18
+
19
+ class StreamTimer : public ITimer {
20
+ public:
21
+ StreamTimer();
22
+ ~StreamTimer() override;
23
+
24
+ void Start() override;
25
+
26
+ void End() override;
27
+
28
+ float Duration() override;
29
+
30
+ private:
31
+ cudaEvent_t start_{};
32
+ cudaEvent_t end_{};
33
+ };
34
+
35
+ class StreamTimerNoSync : public ITimer {
36
+ public:
37
+ StreamTimerNoSync();
38
+ ~StreamTimerNoSync() override;
39
+
40
+ void Start() override;
41
+
42
+ void End() override;
43
+
44
+ float Duration() override;
45
+
46
+ private:
47
+ cudaEvent_t start_{};
48
+ cudaEvent_t end_{};
49
+ };
50
+
51
+ } // namespace at::cuda::tunable
52
+
53
+ #else
54
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
55
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/Tunable.h ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Original TunableOp is from onnxruntime.
3
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
4
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
5
+ // Copyright (c) Microsoft Corporation.
6
+ // Licensed under the MIT license.
7
+ //
8
+ // Adapting TunableOp into PyTorch
9
+ // Copyright (c) Advanced Micro Devices, Inc.
10
+ //
11
+ #pragma once
12
+
13
+ #include <c10/util/CallOnce.h>
14
+ #include <c10/util/StringUtil.h>
15
+ #include <c10/util/env.h>
16
+
17
+ #include <fstream>
18
+ #include <functional>
19
+ #include <iostream>
20
+ #include <memory>
21
+ #include <mutex>
22
+ #include <string>
23
+ #include <unordered_map>
24
+ #include <unordered_set>
25
+ #include <utility>
26
+
27
+ #define TUNABLE_LOGV(LEVEL, ...) getTuningContext()->Log(LEVEL, __VA_ARGS__)
28
+ #define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
29
+ #define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
30
+ #define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)
31
+
32
+ namespace at::cuda::tunable {
33
+
34
+ enum TORCH_CUDA_CPP_API TuningStatus {
35
+ OK = 0,
36
+ FAIL = 1,
37
+ UNSUPPORTED = 2,
38
+ };
39
+
40
+ // Mapping from params signature to kernel id
41
+ class TORCH_CUDA_CPP_API ResultEntry {
42
+ public:
43
+ explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {}
44
+ explicit ResultEntry(std::string key, double time, std::string blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(std::move(blas_sig)) {}
45
+ bool operator==(const ResultEntry& other) const { return key_ == other.key_; }
46
+ bool operator!=(const ResultEntry& other) const { return key_ != other.key_; }
47
+ operator std::string () { return key_; }
48
+ std::string GetKey() const { return key_; }
49
+ double GetTime() const { return time_; }
50
+ friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
51
+ static ResultEntry Null() { return ResultEntry("Null", 0.0); }
52
+ static ResultEntry Default() { return ResultEntry("Default", 0.0); }
53
+
54
+ private:
55
+ std::string key_;
56
+ double time_;
57
+ std::string blas_sig_;
58
+ };
59
+
60
+ typedef std::unordered_map<std::string, ResultEntry> KernelMap;
61
+ typedef std::unordered_map<std::string, KernelMap> ResultsMap;
62
+ typedef std::unordered_map<std::string, std::unordered_set<std::string>> UntunedMap;
63
+
64
+ struct TORCH_CUDA_CPP_API TuningResults {
65
+ // Validates if these results are compatible with the libraries
66
+ std::unordered_map<std::string, std::string> validators;
67
+
68
+ // Mapping from Callable signature to Callable's tuning result
69
+ ResultsMap results;
70
+ };
71
+
72
+ class TORCH_CUDA_CPP_API TuningResultsManager {
73
+ public:
74
+ TuningResultsManager() = default;
75
+ ~TuningResultsManager() = default;
76
+
77
+ KernelMap Lookup(const std::string& op_signature);
78
+
79
+ ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);
80
+
81
+ void AddImpl(const std::string& op_signature,
82
+ const std::string& params_signature,
83
+ ResultEntry best,
84
+ KernelMap& kernel_map);
85
+
86
+ void Add(const std::string& op_signature,
87
+ const std::string& params_signature,
88
+ ResultEntry best);
89
+
90
+ void Delete(const std::string& op_signature, const std::string& params_signature);
91
+
92
+ void DisjointMergeImpl(
93
+ const std::string& op_signature,
94
+ const KernelMap& kernel_map,
95
+ /*out*/ ResultsMap& results);
96
+
97
+ void Load(const ResultsMap& results_to_load);
98
+
99
+ ResultsMap Dump();
100
+
101
+ void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);
102
+
103
+ size_t GetSize();
104
+
105
+ void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
106
+ const std::string& params_signature, const std::string& blas_signature);
107
+
108
+ void InitRealtimeAppend(
109
+ const std::string& filename,
110
+ const std::unordered_map<std::string, std::string>& validators);
111
+
112
+ void AppendResultLine(const std::string& op_sig,
113
+ const std::string& param_sig,
114
+ const ResultEntry& result);
115
+
116
+ void CloseRealtimeAppend(); // For clean shutdown
117
+ private:
118
+ std::mutex lock_;
119
+ std::mutex realtime_file_mutex_;
120
+ std::unique_ptr<std::ofstream> realtime_out_;
121
+ std::string realtime_filename_;
122
+ ResultsMap results_;
123
+ UntunedMap untuned_results_;
124
+ bool validators_written_ = false;
125
+
126
+ };
127
+
128
+ class TORCH_CUDA_CPP_API TuningResultsValidator {
129
+ public:
130
+ using GetFunc = std::function<std::string()>;
131
+ using ValidateFunc = std::function<TuningStatus(const std::string&)>;
132
+ using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;
133
+
134
+ TuningResultsValidator();
135
+ ~TuningResultsValidator() = default;
136
+
137
+ std::unordered_map<std::string, std::string> GetAllValidators() const;
138
+ TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
139
+ void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
140
+
141
+ protected:
142
+ static std::string GetPyTorchVersion() ;
143
+ TuningStatus ValidatePyTorchVersion(const std::string& value) const;
144
+
145
+ public:
146
+ static constexpr const std::array mandatory_keys{"PT_VERSION"};
147
+
148
+ private:
149
+ GetValidateFuncs validators_;
150
+ };
151
+
152
+ struct NumericalCheckConfig {
153
+ bool enabled{false};
154
+ double atol{1e-5};
155
+ double rtol{1e-5};
156
+
157
+ NumericalCheckConfig() = default;
158
+ NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
159
+ };
160
+
161
+
162
+ class TORCH_CUDA_CPP_API TuningContext {
163
+ public:
164
+ TuningContext();
165
+ ~TuningContext();
166
+ TuningContext(TuningContext &) = delete;
167
+ TuningContext(TuningContext &&) = delete;
168
+ TuningContext &operator=(TuningContext &) = delete;
169
+ TuningContext &operator=(TuningContext &&) = delete;
170
+
171
+ void EnableTunableOp(bool value);
172
+ bool IsTunableOpEnabled() const;
173
+
174
+ void EnableTuning(bool value);
175
+ bool IsTuningEnabled() const;
176
+
177
+ void EnableRecordUntuned(bool value);
178
+ bool IsRecordUntunedEnabled() const;
179
+ std::ofstream& GetUntunedFile();
180
+
181
+ void EnableNumericsCheck(bool value);
182
+ bool IsNumericsCheckEnabled() const;
183
+ void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
184
+ NumericalCheckConfig GetNumericalCheckConfig() const;
185
+
186
+ void SetMaxTuningDurationMs(int max_duration_ms);
187
+ int GetMaxTuningDurationMs() const;
188
+
189
+ void SetMaxTuningIterations(int max_iter);
190
+ int GetMaxTuningIterations() const;
191
+
192
+ void SetMaxWarmupDurationMs(int max_duration_ms);
193
+ int GetMaxWarmupDurationMs() const;
194
+
195
+ void SetMaxWarmupIterations(int max_iter);
196
+ int GetMaxWarmupIterations() const;
197
+
198
+ void EnableICacheFlush(bool value);
199
+ bool IsICacheFlushEnabled() const;
200
+
201
+ void SetRotatingBufferSize(int size);
202
+ int GetRotatingBufferSize() const;
203
+
204
+ TuningResultsManager& GetTuningResultsManager();
205
+
206
+ TuningResultsValidator& GetTuningResultsValidator();
207
+
208
+ TuningResults GetTuningResults();
209
+
210
+ TuningStatus LoadTuningResults(const TuningResults& tr);
211
+
212
+ void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
213
+ std::string GetFilename() const;
214
+
215
+ bool ReadFile(const std::string& filename={});
216
+
217
+ template<class... Types>
218
+ void Log(int level, Types... args) {
219
+ if (GetLogOkay() && GetLogLevel() >= level) {
220
+ GetLog() << c10::str(args...) << std::endl;
221
+ }
222
+ }
223
+
224
+ private:
225
+ std::string GetLogFilename() const;
226
+ int GetLogLevel() const;
227
+ bool GetLogOkay() const;
228
+ std::ostream& GetLog() const;
229
+
230
+ bool enable_;
231
+ bool tuning_enable_;
232
+ bool record_untuned_enable_;
233
+ bool manager_initialized_;
234
+ bool numerics_check_enable_;
235
+ int max_tuning_duration_ms_;
236
+ int max_tuning_iterations_;
237
+ int max_warmup_duration_ms_;
238
+ int max_warmup_iterations_;
239
+ bool icache_flush_;
240
+ int rotating_buffer_size_;
241
+ mutable TuningResultsManager manager_;
242
+ mutable c10::once_flag manager_init_once_;
243
+ TuningResultsValidator validator_;
244
+ std::string filename_;
245
+ std::ofstream untuned_file_;
246
+ size_t results_count_from_input_file_;
247
+ bool is_shutting_down_;
248
+
249
+ NumericalCheckConfig numerics_cfg_{};
250
+ };
251
+
252
+ TORCH_CUDA_CPP_API TuningContext* getTuningContext();
253
+
254
+ class ITimer {
255
+ public:
256
+ ITimer() = default;
257
+ virtual ~ITimer() = default;
258
+
259
+ virtual void Start() = 0;
260
+ virtual void End() = 0;
261
+
262
+ /// Computes the elapsed time in milliseconds between Start() and End()
263
+ virtual float Duration() = 0;
264
+ };
265
+
266
+ } // namespace at::cuda::tunable
267
+
268
+ #else
269
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
270
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Original TunableOp is from onnxruntime.
3
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
4
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
5
+ // Copyright (c) Microsoft Corporation.
6
+ // Licensed under the MIT license.
7
+ //
8
+ // Adapting TunableOp into PyTorch
9
+ // Copyright (c) Advanced Micro Devices, Inc.
10
+ //
11
+ #pragma once
12
+
13
+ #include <ATen/cuda/tunable/GemmCommon.h>
14
+ #ifdef USE_ROCM
15
+ #include <ATen/cuda/tunable/GemmHipblaslt.h>
16
+ #include <ATen/cuda/tunable/GemmRocblas.h>
17
+ #endif
18
+ #include <ATen/cuda/tunable/TunableOp.h>
19
+ #include <c10/cuda/CUDACachingAllocator.h>
20
+ #include <c10/util/Float8_e4m3fn.h>
21
+ #include <c10/util/Float8_e4m3fnuz.h>
22
+ #include <c10/util/Float8_e5m2.h>
23
+ #include <c10/util/Float8_e5m2fnuz.h>
24
+ #include <c10/util/Float8_e8m0fnu.h>
25
+ #include <c10/util/StringUtil.h>
26
+ #include <fmt/printf.h>
27
+
28
+ namespace at::cuda::tunable {
29
+
30
+ template <typename T>
31
+ class DefaultGemmOp : public Callable<GemmParams<T>> {
32
+ public:
33
+ TuningStatus Call(const GemmParams<T>* params) override {
34
+ at::cuda::blas::gemm_internal<T>(
35
+ params->transa, params->transb,
36
+ params->m, params->n, params->k,
37
+ params->alpha,
38
+ params->a, params->lda,
39
+ params->b, params->ldb,
40
+ params->beta,
41
+ params->c, params->ldc);
42
+ return OK;
43
+ }
44
+ };
45
+
46
+ static bool _transposeBoolFromChar(char op) {
47
+ return op == 't' || op == 'T';
48
+ }
49
+
50
+ template <typename T>
51
+ class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
52
+ public:
53
+ TuningStatus Call(const GemmAndBiasParams<T>* params) override {
54
+ at::cuda::blas::gemm_and_bias<T>(
55
+ _transposeBoolFromChar(params->transa),
56
+ _transposeBoolFromChar(params->transb),
57
+ params->m, params->n, params->k,
58
+ params->alpha,
59
+ params->a, params->lda,
60
+ params->b, params->ldb,
61
+ params->bias,
62
+ params->c, params->ldc,
63
+ params->activation);
64
+ return OK;
65
+ }
66
+ };
67
+
68
+ template <typename T>
69
+ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
70
+ public:
71
+ TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
72
+ at::cuda::blas::bgemm_internal<T>(
73
+ params->transa, params->transb,
74
+ params->m, params->n, params->k,
75
+ params->alpha,
76
+ params->a, params->lda, params->stride_a,
77
+ params->b, params->ldb, params->stride_b,
78
+ params->beta,
79
+ params->c, params->ldc, params->stride_c,
80
+ params->batch);
81
+ return OK;
82
+ }
83
+ };
84
+
85
+ template <typename T>
86
+ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
87
+ public:
88
+ TuningStatus Call(const ScaledGemmParams<T>* params) override {
89
+ at::cuda::blas::scaled_gemm(
90
+ params->transa,
91
+ params->transb,
92
+ params->m,
93
+ params->n,
94
+ params->k,
95
+ params->a,
96
+ params->a_scale_ptr,
97
+ params->lda,
98
+ params->a_dtype,
99
+ params->a_scale_dtype,
100
+ params->a_scaling_type,
101
+ params->b,
102
+ params->b_scale_ptr,
103
+ params->ldb,
104
+ params->b_dtype,
105
+ params->b_scale_dtype,
106
+ params->b_scaling_type,
107
+ params->bias_ptr,
108
+ params->bias_dtype,
109
+ params->c,
110
+ params->c_scale_ptr,
111
+ params->ldc,
112
+ params->c_dtype,
113
+ params->use_fast_accum,
114
+ std::nullopt /* alpha */);
115
+ return OK;
116
+ }
117
+ };
118
+
119
+ template <typename T>
120
+ inline bool IsZero(T v) {
121
+ return v == 0.0f;
122
+ }
123
+
124
+ template <>
125
+ inline bool IsZero(BFloat16 v) {
126
+ return v.x == 0;
127
+ }
128
+
129
+ template <>
130
+ inline bool IsZero(Half v) {
131
+ return float(v) == 0.0f;
132
+ }
133
+
134
+ template <>
135
+ inline bool IsZero(c10::complex<double> v) {
136
+ return v == 0.0;
137
+ }
138
+
139
+ template <>
140
+ inline bool IsZero(c10::complex<float> v) {
141
+ return v == 0.0f;
142
+ }
143
+
144
+ template <typename T>
145
+ inline const char* TypeName(T v) {
146
+ return "unknown";
147
+ }
148
+
149
+ template <>
150
+ inline const char* TypeName(float v) {
151
+ if (at::globalContext().allowTF32CuBLAS()) {
152
+ return "tf32";
153
+ } else {
154
+ return "float";
155
+ }
156
+ }
157
+
158
+ template <>
159
+ inline const char* TypeName(double v) {
160
+ return "double";
161
+ }
162
+
163
+ template <>
164
+ inline const char* TypeName(BFloat16 v) {
165
+ return "BFloat16";
166
+ }
167
+
168
+ template <>
169
+ inline const char* TypeName(Half v) {
170
+ return "Half";
171
+ }
172
+
173
+ template <>
174
+ inline const char* TypeName(Float8_e4m3fn v) {
175
+ return "Float8_e4m3fn";
176
+ }
177
+
178
+ template <>
179
+ inline const char* TypeName(Float8_e5m2 v) {
180
+ return "Float8_e5m2";
181
+ }
182
+
183
+ template <>
184
+ inline const char* TypeName(Float8_e4m3fnuz v) {
185
+ return "Float8_e4m3fnuz";
186
+ }
187
+
188
+ template <>
189
+ inline const char* TypeName(Float8_e5m2fnuz v) {
190
+ return "Float8_e5m2fnuz";
191
+ }
192
+
193
+ template <>
194
+ inline const char* TypeName(Float8_e8m0fnu v) {
195
+ return "Float8_e8m0fnu";
196
+ }
197
+
198
+ template <>
199
+ inline const char* TypeName(c10::complex<double> v) {
200
+ return "c10::complex<double>";
201
+ }
202
+
203
+ template <>
204
+ inline const char* TypeName(c10::complex<float> v) {
205
+ return "c10::complex<float>";
206
+ }
207
+
208
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
209
+ class GemmTunableOp : public TunableOp<GemmParams<T>> {
210
+ public:
211
+ GemmTunableOp() {
212
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
213
+
214
+ #ifdef USE_ROCM
215
+ static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
216
+ if (!env_rocblas.has_value() || env_rocblas.value()) {
217
+ for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
218
+ this->RegisterOp(std::move(name), std::move(op));
219
+ }
220
+ }
221
+
222
+ static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
223
+ if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
224
+ // disallow tuning of hipblaslt with c10::complex
225
+ if constexpr (
226
+ !std::is_same_v<T, c10::complex<float>> &&
227
+ !std::is_same_v<T, c10::complex<double>>) {
228
+ for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
229
+ this->RegisterOp(std::move(name), std::move(op));
230
+ }
231
+ }
232
+ }
233
+ #endif
234
+
235
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
236
+ }
237
+
238
+ std::string Signature() override {
239
+ return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
240
+ }
241
+ };
242
+
243
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
244
+ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>> {
245
+ public:
246
+ GemmAndBiasTunableOp() {
247
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
248
+
249
+ #ifdef USE_ROCM
250
+ static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
251
+ if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
252
+ // disallow tuning of hipblaslt with c10::complex
253
+ if constexpr (
254
+ !std::is_same_v<T, c10::complex<float>> &&
255
+ !std::is_same_v<T, c10::complex<double>>) {
256
+ for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
257
+ this->RegisterOp(std::move(name), std::move(op));
258
+ }
259
+ }
260
+ }
261
+ #endif
262
+
263
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
264
+ }
265
+
266
+ std::string Signature() override {
267
+ return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
268
+ }
269
+ };
270
+
271
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
272
+ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
273
+ public:
274
+ GemmStridedBatchedTunableOp() {
275
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
276
+
277
+ #ifdef USE_ROCM
278
+ static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
279
+ if (!env_rocblas.has_value() || env_rocblas.value()) {
280
+ for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
281
+ this->RegisterOp(std::move(name), std::move(op));
282
+ }
283
+ }
284
+
285
+ static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
286
+ if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
287
+ // disallow tuning of hipblaslt with c10::complex
288
+ if constexpr (
289
+ !std::is_same_v<T, c10::complex<float>> &&
290
+ !std::is_same_v<T, c10::complex<double>>) {
291
+ for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
292
+ this->RegisterOp(std::move(name), std::move(op));
293
+ }
294
+ }
295
+ }
296
+ #endif
297
+
298
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
299
+ }
300
+
301
+ std::string Signature() override {
302
+ return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
303
+ }
304
+ };
305
+
306
+ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
307
+ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>> {
308
+ public:
309
+ ScaledGemmTunableOp() {
310
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
311
+
312
+ #ifdef USE_ROCM
313
+ for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
314
+ this->RegisterOp(std::move(name), std::move(op));
315
+ }
316
+ #endif
317
+
318
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
319
+ }
320
+
321
+ std::string Signature() override {
322
+ return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c",
323
+ TypeName<AT>(AT{}),
324
+ TypeName<BT>(BT{}),
325
+ TypeName<CT>(CT{}),
326
+ BlasOpToString(ALayout), BlasOpToString(BLayout));
327
+ }
328
+ };
329
+
330
+ } // namespace at::cuda::tunable
331
+
332
+ #else
333
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
334
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Original TunableOp is from onnxruntime.
3
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
4
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
5
+ // Copyright (c) Microsoft Corporation.
6
+ // Licensed under the MIT license.
7
+ //
8
+ // Adapting TunableOp into PyTorch
9
+ // Copyright (c) Advanced Micro Devices, Inc.
10
+ //
11
+ #pragma once
12
+
13
+ #include <ATen/cuda/tunable/Tunable.h>
14
+ #include <ATen/cuda/tunable/StreamTimer.h>
15
+ #include <ATen/cuda/Sleep.h>
16
+ #include <c10/cuda/CUDACachingAllocator.h>
17
+
18
+ #ifndef _WIN32
19
+ #include <cxxabi.h>
20
+ #endif
21
+
22
+ #include <string>
23
+ #include <unordered_map>
24
+ #include <vector>
25
+ #include <deque>
26
+
27
+ namespace at::cuda::tunable {
28
+
29
+ template <typename ParamsT>
30
+ class Callable {
31
+ public:
32
+ virtual ~Callable() = default;
33
+ virtual TuningStatus Call(const ParamsT* /*unused*/) {
34
+ return FAIL;
35
+ }
36
+ virtual TuningStatus IsSupported(const ParamsT* params) {
37
+ return Call(params);
38
+ }
39
+ };
40
+
41
+ namespace {
42
+
43
+ /** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */
44
+
45
+ class Stats {
46
+ public:
47
+ Stats() {
48
+ _n = 0UL;
49
+ _mean = 0.0;
50
+ _M2 = 0.0;
51
+ _sum = 0.0;
52
+ _min = 0.0;
53
+ _max = 0.0;
54
+ }
55
+
56
+ void sample_value(const double x) {
57
+ double delta = 0;
58
+ _sum = _sum + x;
59
+ if (0UL == _n) {
60
+ _min = x;
61
+ _max = x;
62
+ }
63
+ else {
64
+ _min = _min < x ? _min : x;
65
+ _max = _max > x ? _max : x;
66
+ }
67
+ _n = _n + 1UL;
68
+ delta = x - _mean;
69
+ _mean = _mean + delta/_n;
70
+ _M2 = _M2 + delta * (x - _mean);
71
+ }
72
+
73
+ double variance() const {
74
+ return _M2/(_n-1);
75
+ }
76
+
77
+ double stddev() const {
78
+ return std::sqrt(variance());
79
+ }
80
+
81
+ unsigned long _n;
82
+ double _mean;
83
+ double _M2;
84
+ double _sum;
85
+ double _min;
86
+ double _max;
87
+ };
88
+
89
+ class FixedSizeStack {
90
+ private:
91
+ std::deque<std::string> stack;
92
+ const size_t max_size;
93
+
94
+ public:
95
+ FixedSizeStack(size_t size) : max_size(size) {}
96
+
97
+ void push(const std::string& value) {
98
+ if (stack.size() >= max_size) {
99
+ stack.pop_front(); // Remove the oldest entry
100
+ }
101
+ stack.push_back(value); // Add new entry
102
+ }
103
+
104
+ auto rbegin() { return stack.rbegin(); }
105
+ auto rend() { return stack.rend(); }
106
+ };
107
+
108
+ } // anonymous namespace
109
+
110
+ template <typename ParamsT>
111
+ class TunableOp {
112
+ public:
113
+ virtual ~TunableOp() = default;
114
+
115
+ TuningStatus operator()(const ParamsT* params) {
116
+ ResultEntry result = ResultEntry::Null();
117
+ TuningContext* ctx = getTuningContext();
118
+ if (ctx->IsTunableOpEnabled()) {
119
+ auto& mgr = ctx->GetTuningResultsManager();
120
+ auto op_sig = Signature();
121
+ auto params_sig = params->Signature();
122
+ auto blas_sig = params->BLASSignature();
123
+ result = mgr.Lookup(op_sig, params_sig);
124
+ // If there is not previous tuning result been found, we do the tuning iff tuning is enabled
125
+ if (result == ResultEntry::Null()) {
126
+ if (ctx->IsTuningEnabled()) {
127
+ result = FindFastest(params);
128
+ mgr.Add(op_sig, params_sig, result);
129
+ }
130
+ else if (ctx->IsRecordUntunedEnabled()) {
131
+ // or record the gemm into file
132
+ mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig, blas_sig);
133
+ }
134
+ }
135
+ }
136
+ else {
137
+ result = ResultEntry::Default();
138
+ }
139
+ if (result == ResultEntry::Null()) {
140
+ TUNABLE_LOG2("no result, using default");
141
+ result = ResultEntry::Default();
142
+ }
143
+ auto iter = ops_.find(result);
144
+ TORCH_CHECK(iter != ops_.end());
145
+ return iter->second->Call(params);
146
+ }
147
+
148
+ virtual std::string Signature() {
149
+ // According to C++17 standard https://wg21.link/n4659 section 15.7.4
150
+ // > if the operand of typeid refers to the
151
+ // > object under construction or destruction, typeid yields the std::type_info object representing the constructor
152
+ // > or destructor’s class.
153
+ // So delay the op signature generation.
154
+ c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
155
+ return signature_;
156
+ }
157
+
158
+ protected:
159
+ void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
160
+ this->op_names_.emplace_back(name);
161
+ this->ops_.emplace(name, std::move(op));
162
+ }
163
+
164
+ private:
165
+ static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
166
+ TuningContext* ctx = getTuningContext();
167
+ bool do_flush = ctx->IsICacheFlushEnabled();
168
+ for (size_t i = 0; i < num_iter; i++) {
169
+ if (do_flush) {
170
+ at::cuda::flush_icache();
171
+ }
172
+ TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
173
+ }
174
+ }
175
+
176
+ static double ProfileSimple(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
177
+ TuningContext* ctx = getTuningContext();
178
+ bool do_flush = ctx->IsICacheFlushEnabled();
179
+ StreamTimerNoSync timer{};
180
+
181
+ // Small Mandatory Warmup
182
+ // Reduces outliers
183
+ for (size_t i = 0; i < 2; i++) {
184
+ TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
185
+ }
186
+
187
+ timer.Start();
188
+ for (size_t i = 0; i < num_iter; i++) {
189
+ if (do_flush) {
190
+ at::cuda::flush_icache();
191
+ }
192
+ TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
193
+ }
194
+ timer.End();
195
+ return timer.Duration() / num_iter;
196
+ }
197
+
198
+ static Stats ProfileStats(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
199
+ TuningContext* ctx = getTuningContext();
200
+ bool do_flush = ctx->IsICacheFlushEnabled();
201
+ std::vector<StreamTimerNoSync> timer(num_iter);
202
+
203
+ // Small Mandatory Warmup
204
+ // Reduces outliers
205
+ for (size_t i = 0; i < 2; i++) {
206
+ TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
207
+ }
208
+
209
+ for (size_t i = 0; i < num_iter; i++) {
210
+ timer[i].Start();
211
+ TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
212
+ timer[i].End();
213
+ if (do_flush) {
214
+ at::cuda::flush_icache();
215
+ }
216
+ }
217
+ Stats s;
218
+ for (size_t i = 0; i < num_iter; i++) {
219
+ s.sample_value(timer[i].Duration());
220
+ }
221
+ return s;
222
+ }
223
+
224
+ protected:
225
+ virtual ResultEntry FindFastest(const ParamsT* params) {
226
+ TuningContext* ctx = getTuningContext();
227
+ auto op_sig = Signature();
228
+ auto params_sig = params->Signature();
229
+ auto blas_sig = params->BLASSignature();
230
+ TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
231
+ auto min_duration_ms = std::numeric_limits<double>::infinity();
232
+ std::string id_name = "Default";
233
+ ParamsT* reference_params = nullptr;
234
+ auto top_solns = FixedSizeStack(5);
235
+
236
+ // numeric check option is controlled by non-static env var, so check it once per tuned operator
237
+ bool do_numerics_check = ctx->IsNumericsCheckEnabled();
238
+
239
+ // calculate a reference answer for numerical check
240
+ if (do_numerics_check) {
241
+ reference_params = params->DeepCopy(false);
242
+ TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
243
+ }
244
+
245
+ // need copies of params to reuse
246
+ // make as many copies as will fill the requested rotating buffer size, if requested
247
+ // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int
248
+ size_t rotating_size = ctx->GetRotatingBufferSize();
249
+ bool use_buffer_rotation = (rotating_size > 0);
250
+ size_t param_size = params->GetSize(use_buffer_rotation);
251
+ size_t param_count = (rotating_size / param_size) + 1;
252
+ constexpr size_t MB = 1024ull*1024;
253
+ if (use_buffer_rotation) {
254
+ TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ",
255
+ "Needed Size: ", param_size/MB, " MiB. ",
256
+ "Needed number of param copies: ", param_count);
257
+ }
258
+ TORCH_CHECK(param_count > 0);
259
+
260
+ std::vector<ParamsT*> reusable_params(param_count);
261
+ for (size_t i = 0; i < param_count; i++) {
262
+ reusable_params[i] = params->DeepCopy(use_buffer_rotation);
263
+ }
264
+
265
+ // for rotating buffer
266
+ size_t offset = 0;
267
+
268
+ for (size_t i = 0; i < op_names_.size(); i++) {
269
+ auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
270
+
271
+ auto status = candidate->Call(reusable_params[0]);
272
+ if (status != OK) {
273
+ TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
274
+ continue;
275
+ }
276
+
277
+ // collect a small profile
278
+ int approx_num_iter = 3;
279
+ auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
280
+ double approx_duration = s._mean;
281
+ // bail if too slow
282
+ if (approx_duration > 1.5 * min_duration_ms) {
283
+ TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
284
+ continue;
285
+ }
286
+
287
+ // 2nd phase skip, more aggressive
288
+ approx_num_iter = 10;
289
+ s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
290
+ approx_duration = s._mean;
291
+ // bail if too slow
292
+ if (approx_duration > 1.15 * min_duration_ms) {
293
+ TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
294
+ continue;
295
+ }
296
+
297
+ if (do_numerics_check) {
298
+ ParamsT* numerical_params = params->DeepCopy(false);
299
+ auto status = candidate->Call(numerical_params);
300
+ if (status != OK) {
301
+ numerical_params->Delete();
302
+ TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
303
+ continue;
304
+ }
305
+ status = reference_params->NumericalCheck(numerical_params);
306
+ numerical_params->Delete();
307
+ if (status != OK) {
308
+ TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
309
+ continue;
310
+ }
311
+ }
312
+
313
+ // for warmup does user set max duration, max iters, or both?
314
+ // warmup is skipped by default, i.e. warmup_iter = 0
315
+ // warmup will be set to the non-zero value of max_warmup_duration
316
+ // or max_warmup_iter
317
+ // if both are non-zero, we take the smaller of the two.
318
+ double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
319
+ int max_warmup_iter = ctx->GetMaxWarmupIterations();
320
+ int warmup_iter = 0; // default
321
+ if (max_warmup_duration > 0) {
322
+ int duration_iters = max_warmup_duration / approx_duration;
323
+ if (max_warmup_iter > 0) {
324
+ warmup_iter = std::min(max_warmup_iter, duration_iters);
325
+ }
326
+ else {
327
+ warmup_iter = duration_iters;
328
+ }
329
+ }
330
+ else if (max_warmup_iter > 0) {
331
+ warmup_iter = max_warmup_iter;
332
+ }
333
+
334
+ // for tuning does user set max duration, max iters, or both?
335
+ double max_tuning_duration = ctx->GetMaxTuningDurationMs();
336
+ int max_tuning_iter = ctx->GetMaxTuningIterations();
337
+ int tuning_iter = 100; // default
338
+ if (max_tuning_duration > 0) {
339
+ int duration_iters = max_tuning_duration / approx_duration;
340
+ if (max_tuning_iter > 0) {
341
+ tuning_iter = std::min(max_tuning_iter, duration_iters);
342
+ }
343
+ else {
344
+ tuning_iter = duration_iters;
345
+ }
346
+ }
347
+ else if (max_tuning_iter > 0) {
348
+ tuning_iter = max_tuning_iter;
349
+ }
350
+ // tuning must run at least 1 iteration
351
+ tuning_iter = std::max(1, tuning_iter);
352
+
353
+ // do the full warmup followed by tuning
354
+ double warmup_ms = warmup_iter * approx_duration;
355
+ double tuning_ms = tuning_iter * approx_duration;
356
+ TUNABLE_LOG3("├──tuning using "
357
+ "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
358
+ "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
359
+ "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
360
+ TUNABLE_LOG3("├──offset at ", offset);
361
+ WarmUp(candidate, reusable_params, warmup_iter, offset);
362
+ s = ProfileStats(candidate, reusable_params, tuning_iter, offset);
363
+ auto s_stddev = s.stddev();
364
+ // Assume normal distribution.
365
+ // Solution with smallest mean + 2*sigma will be a better solution?
366
+ // if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) {
367
+ if (s._mean < min_duration_ms) {
368
+ TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
369
+ " min ", s._min,
370
+ " max ", s._max,
371
+ " mean ", s._mean,
372
+ " std ", s_stddev);
373
+ min_duration_ms = s._mean;
374
+ id_name = op_names_[i];
375
+ std::string current_soln = std::to_string(s._mean) + " " + op_names_[i];
376
+ top_solns.push(current_soln);
377
+ }
378
+ else {
379
+ TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
380
+ " min ", s._min,
381
+ " max ", s._max,
382
+ " mean ", s._mean,
383
+ " std ", s_stddev);
384
+ }
385
+ }
386
+
387
+ for (size_t i = 0; i < reusable_params.size(); i++) {
388
+ reusable_params[i]->Delete();
389
+ }
390
+ if (reference_params) {
391
+ reference_params->Delete();
392
+ }
393
+
394
+ TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
395
+ TUNABLE_LOG2("└──top five solutions for ", op_sig, '(', params_sig, ") ");
396
+ for (auto it = top_solns.rbegin(); it != top_solns.rend(); ++it) {
397
+ TUNABLE_LOG2(" ", *it);
398
+ }
399
+ return ResultEntry(id_name, min_duration_ms, blas_sig);
400
+ }
401
+
402
+ private:
403
+ std::string CreateSignature() {
404
+ #ifndef _WIN32
405
+ const auto* name = typeid(*this).name();
406
+ // NOLINTNEXTLINE(*array*)
407
+ char buf[256];
408
+ size_t buf_len = 256;
409
+ abi::__cxa_demangle(name, buf, &buf_len, nullptr);
410
+ buf[255] = '\0';
411
+ return buf;
412
+ #else
413
+ return typeid(*this).name();
414
+ #endif
415
+ }
416
+
417
+ mutable c10::once_flag signature_init_once_;
418
+ std::string signature_;
419
+
420
+ std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
421
+ std::vector<std::string> op_names_;
422
+ };
423
+
424
+ struct OpParams {
425
+ virtual ~OpParams() = default;
426
+ virtual std::string Signature() const = 0;
427
+ virtual std::string BLASSignature() const = 0;
428
+ };
429
+
430
+ } // namespace at::cuda::tunable
431
+
432
+ #else
433
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
434
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/functorch/Interpreter.h>
4
+
5
+ namespace at::functorch {
6
+
7
+ // These are the interpreters for our AD transforms
8
+ // (grad, vjp and jvp).
9
+ // See NOTE: [functorch interpreter stack] for more details.
10
+
11
+ struct TORCH_API GradInterpreterPtr {
12
+ explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
13
+ TransformType key() const { return base_->key(); }
14
+ int64_t level() const { return base_->level(); }
15
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
16
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
17
+ bool prevGradMode() const {
18
+ return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
19
+ }
20
+ Tensor lift(const Tensor& tensor) const;
21
+ private:
22
+ const Interpreter* base_;
23
+ };
24
+
25
+ struct TORCH_API JvpInterpreterPtr {
26
+ explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
27
+ TransformType key() const { return base_->key(); }
28
+ int64_t level() const { return base_->level(); }
29
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
30
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
31
+ bool prevFwdGradMode() const {
32
+ return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
33
+ }
34
+ Tensor lift(const Tensor& tensor) const;
35
+ private:
36
+ const Interpreter* base_;
37
+ };
38
+
39
+ } // namespace at::functorch
40
+
41
+ #else
42
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
43
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+ #pragma once
8
+
9
+ #include <c10/util/TypeList.h>
10
+
11
+ #include <ATen/ATen.h>
12
+ #include <ATen/Operators.h>
13
+
14
+ #include <ATen/functorch/DynamicLayer.h>
15
+ #include <ATen/functorch/TensorWrapper.h>
16
+ #include <ATen/functorch/BatchingMetaprogramming.h>
17
+ #include <ATen/functorch/LegacyVmapTransforms.h>
18
+ #include <ATen/functorch/BatchedFallback.h>
19
+ #include <ATen/functorch/PlumbingHelper.h>
20
+ #include <ATen/core/dispatch/Dispatcher.h>
21
+ #include <ATen/VmapGeneratedPlumbing.h>
22
+
23
+ #include <utility>
24
+
25
+ // This file contains helper functions for batching rules.
26
+
27
+ namespace at::functorch {
28
+
29
+ TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
30
+ TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
31
+
32
+ TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);
33
+
34
+ Tensor moveBatchDimToFront(Tensor tensor, std::optional<int64_t> maybe_batch_dim);
35
+ int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
36
+ int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
37
+ std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val);
38
+ int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
39
+ VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);
40
+
41
+ void vmapIncompatibleInplaceError(const char* schema_name);
42
+
43
+ Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank);
44
+
45
+ void check_randomness(RandomnessType randomness);
46
+ void check_randomness(RandomnessType randomness, bool any_tensor_bdim);
47
+
48
+ inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
49
+ if (has_bdim) {
50
+ return tensor;
51
+ }
52
+ const auto sizes = tensor.sym_sizes();
53
+ SymDimVector expanded_shape;
54
+ expanded_shape.reserve(sizes.size());
55
+ expanded_shape.emplace_back(std::move(batch_size));
56
+ expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
57
+ return tensor.expand_symint(expanded_shape);
58
+ }
59
+
60
+ #define VMAP_SUPPORT(op, batch_rule) \
61
+ m.impl(#op, op ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);
62
+
63
+ #define VMAP_SUPPORT2(op, overload, batch_rule) \
64
+ m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);
65
+
66
+ #define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
67
+ #define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
68
+
69
+ // DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
70
+ template <typename A, A a, typename C>
71
+ struct BasicUnaryBatchRuleHelper;
72
+
73
+ template <typename F, F Func, typename A, typename... T>
74
+ struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
75
+ static std::tuple<Tensor, std::optional<int64_t>> apply(
76
+ const Tensor& tensor,
77
+ std::optional<int64_t> batch_dim,
78
+ T... extra_args) {
79
+ return std::make_tuple(Func(tensor, std::forward<T>(extra_args)...), batch_dim);
80
+ }
81
+ };
82
+
83
+ // USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
84
+ // INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
85
+ // It is important that this macro is not passed a function pointer!!
86
+ #define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
87
+ BasicUnaryBatchRuleHelper<\
88
+ decltype(&fn),\
89
+ &fn,\
90
+ c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
91
+
92
+ #define UNARY_POINTWISE(op) \
93
+ VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
94
+
95
+ template <typename A, A a, typename C>
96
+ struct VariadicBdimsBatchRuleHelper;
97
+
98
+ template <typename F, F Func, typename A, typename... T>
99
+ struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
100
+ static std::tuple<Tensor, std::optional<int64_t>> apply(
101
+ const Tensor& tensor,
102
+ std::optional<int64_t> batch_dim,
103
+ T... extra_args) {
104
+ auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
105
+ return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
106
+ }
107
+ };
108
+
109
+ // USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
110
+ // INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
111
+ // It is important that this macro is not passed a function pointer!!
112
+ #define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
113
+ VariadicBdimsBatchRuleHelper<\
114
+ decltype(&fn),\
115
+ &fn,\
116
+ c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
117
+
118
+ #define VARIADIC_BDIMS(op) \
119
+ VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));
120
+
121
+ #define VARIADIC_BDIMS2(op, overload) \
122
+ VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
123
+
124
+ template<class F, F Func>
125
+ void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
126
+ const auto& schema = op.schema();
127
+ const auto num_returns = schema.returns().size();
128
+ const auto num_arguments = schema.arguments().size();
129
+
130
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
131
+ auto maybe_layer = maybeCurrentDynamicLayer();
132
+ vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");
133
+
134
+ int64_t cur_level = maybe_layer->layerId();
135
+
136
+ auto orig_arguments = torch::jit::last(*stack, num_arguments);
137
+ if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
138
+ op.callBoxed(stack);
139
+ return;
140
+ }
141
+
142
+ auto arguments = torch::jit::pop(*stack, num_arguments);
143
+ std::vector<std::pair<Tensor, std::optional<int64_t>>> tensor_inputs;
144
+ std::vector<int64_t> tensor_pos;
145
+ for (const auto idx : c10::irange(0, num_arguments)) {
146
+ const auto& ivalue = arguments[idx];
147
+ if (ivalue.isTensor()) {
148
+ auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
149
+ tensor_inputs.emplace_back(std::move(tensor_value), tensor_bdim);
150
+ tensor_pos.push_back(static_cast<int64_t>(idx));
151
+ }
152
+ }
153
+ Func(tensor_inputs);
154
+
155
+ size_t tensor_idx = 0;
156
+ TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
157
+ for (const auto arg_idx : c10::irange(0, num_arguments)) {
158
+ if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
159
+ torch::jit::push(stack, arguments[arg_idx]);
160
+ } else {
161
+ TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
162
+ torch::jit::push(stack, tensor_inputs[tensor_idx].first);
163
+ tensor_idx++;
164
+ }
165
+ }
166
+
167
+ op.callBoxed(stack);
168
+ const auto returns = torch::jit::pop(*stack, num_returns);
169
+ for (const auto& ret : returns) {
170
+ if (ret.isTensor()) {
171
+ torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
172
+ } else {
173
+ TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
174
+ }
175
+ }
176
+ }
177
+
178
+ inline void handle_pointwise_ops(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
179
+ int64_t out_logical_rank = 0;
180
+ for (auto& tensor_input : tensor_inputs) {
181
+ int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
182
+ out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
183
+ }
184
+ for (auto& tensor_input: tensor_inputs) {
185
+ tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
186
+ tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
187
+ }
188
+ }
189
+
190
+ #define POINTWISE_BOXED(op) \
191
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
192
+
193
+ #define POINTWISE_BOXED2(op, overload) \
194
+ m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
195
+
196
+ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
197
+ for (auto & tensor_input : tensor_inputs) {
198
+ tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
199
+ }
200
+ }
201
+
202
+ #define VARIADIC_BDIMS_BOXED(op) \
203
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
204
+
205
+ using UnpackedBatchedTensor = std::tuple<Tensor, std::optional<int64_t>>;
206
+
207
+ inline void find_and_unpack_tensors(
208
+ const torch::jit::Stack* stack,
209
+ int64_t num_args,
210
+ int64_t cur_level,
211
+ SmallVector<UnpackedBatchedTensor, 5>* tensors,
212
+ SmallVector<int64_t, 5>* tensors_pos,
213
+ int64_t* batch_size) {
214
+
215
+ int64_t computed_batch_size = -1;
216
+ int64_t args_begin = static_cast<int64_t>(stack->size()) - num_args;
217
+
218
+ for (const auto idx : c10::irange(0, num_args)) {
219
+ const auto& ivalue = (*stack)[args_begin + idx];
220
+ if (!ivalue.isTensor()) {
221
+ continue;
222
+ }
223
+ auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
224
+ const auto& [tensor_value, tensor_bdim] = unpacked;
225
+ if (tensor_bdim.has_value()) {
226
+ auto candidate_batch_size = tensor_value.size(*tensor_bdim);
227
+ if (computed_batch_size == -1) {
228
+ computed_batch_size = candidate_batch_size;
229
+ }
230
+ TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
231
+ }
232
+
233
+ tensors->push_back(std::move(unpacked));
234
+ tensors_pos->push_back(idx);
235
+ }
236
+ TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
237
+ *batch_size = computed_batch_size;
238
+ }
239
+
240
+ inline void boxed_existing_bdim_all_batch_rule(
241
+ const c10::OperatorHandle& op, torch::jit::Stack* stack) {
242
+ const auto& schema = op.schema();
243
+ const auto num_returns = schema.returns().size();
244
+ const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
245
+
246
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
247
+ const auto maybe_layer = maybeCurrentDynamicLayer();
248
+ vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
249
+
250
+ const auto arguments = torch::jit::last(stack, num_arguments);
251
+ if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
252
+ op.callBoxed(stack);
253
+ return;
254
+ }
255
+
256
+ int64_t args_begin = static_cast<int64_t>(stack->size()) - num_arguments;
257
+ SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
258
+ SmallVector<int64_t, 5> tensor_pos;
259
+ int64_t batch_size = 0;
260
+ // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
261
+ int64_t cur_level = maybe_layer->layerId();
262
+
263
+ find_and_unpack_tensors(
264
+ stack, num_arguments, cur_level,
265
+ &tensor_inputs, &tensor_pos, &batch_size);
266
+
267
+ // for each tensor, ensure it has a bdim and reshape it.
268
+ for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
269
+ const auto& [value, bdim] = tensor_inputs[tensor_idx];
270
+ auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
271
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(bdim.value_or(0), 0, value_);
272
+ }
273
+
274
+ op.callBoxed(stack);
275
+
276
+ for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
277
+ const auto& ret = (*stack)[idx];
278
+ TORCH_INTERNAL_ASSERT(ret.isTensor(),
279
+ "This boxed batching rule does not currently support ops that return non-tensor values");
280
+ (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
281
+ }
282
+ }
283
+
284
+ // Use when all tensors arguments accept one (normal) batch dim.
285
+ // This batching rule expands the batch dim on all Tensors, reshapes it into
286
+ // dim 0, calls the op, and then reshapes the batch dim out of dim 0.
287
+ // This is not the most efficient thing; if there are alternatives, please try
288
+ // to use them. Use this only as a last resort.
289
+ #define EXISTING_BDIM_ALL_BOXED(op) \
290
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
291
+
292
+ template <int64_t feature_rank, int64_t contig_tensor_index=-1>
293
+ inline void boxed_all_tensors_have_optional_bdim(
294
+ const c10::OperatorHandle& op, torch::jit::Stack* stack) {
295
+ const auto& schema = op.schema();
296
+ const auto num_returns = schema.returns().size();
297
+ const auto num_arguments = schema.arguments().size();
298
+
299
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
300
+ auto maybe_layer = maybeCurrentDynamicLayer();
301
+ vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
302
+ int64_t cur_level = maybe_layer->layerId();
303
+
304
+ const auto arguments = torch::jit::last(stack, num_arguments);
305
+ if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
306
+ op.callBoxed(stack);
307
+ return;
308
+ }
309
+
310
+ int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
311
+ SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
312
+ SmallVector<int64_t, 5> tensor_pos;
313
+ int64_t batch_size = 0;
314
+
315
+ find_and_unpack_tensors(
316
+ stack, static_cast<int64_t>(num_arguments), cur_level,
317
+ &tensor_inputs, &tensor_pos, &batch_size);
318
+
319
+ std::optional<bool> is_no_batch_dim_case;
320
+
321
+ for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
322
+ const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
323
+ auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
324
+ const auto logical_rank = rankWithoutBatchDim(value, bdim);
325
+
326
+ if (!is_no_batch_dim_case.has_value()) {
327
+ is_no_batch_dim_case = (logical_rank == feature_rank);
328
+ }
329
+ auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
330
+ if (!bdim.has_value()) {
331
+ bdim = 0;
332
+ }
333
+ if (*is_no_batch_dim_case) {
334
+ TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
335
+ value_ = moveBatchDimToFront(value_, bdim);
336
+ if (tensor_idx == contig_tensor_index) {
337
+ value_ = value_.contiguous();
338
+ }
339
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
340
+ continue;
341
+ }
342
+ TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
343
+ value_ = reshape_dim_into(*bdim, 0, value_);
344
+ if (tensor_idx == contig_tensor_index) {
345
+ value_ = value_.contiguous();
346
+ }
347
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
348
+ }
349
+
350
+ op.callBoxed(stack);
351
+
352
+ for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
353
+ const auto& ret = (*stack)[idx];
354
+ TORCH_INTERNAL_ASSERT(ret.isTensor(),
355
+ "This boxed batching rule does not currently support ops that return non-tensor values");
356
+ if (*is_no_batch_dim_case) {
357
+ (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
358
+ } else {
359
+ (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
360
+ }
361
+ }
362
+ }
363
+
364
+ // Useful for many NN operators.
365
+ // The operator must satisfy the following:
366
+ // - All arguments must accept an optional batch dim.
367
+ // - All arguments must be the same rank
368
+ #define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
369
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());
370
+
371
+ #define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
372
+ m.impl(#op, \
373
+ torch::CppFunction::makeFromBoxedFunction<\
374
+ boxed_all_tensors_have_optional_bdim<\
375
+ feature_rank, \
376
+ contig_tensor_index>\
377
+ >());
378
+
379
+ template <typename A, A a, typename C>
380
+ struct ExistingBdimBatchRuleHelper;
381
+
382
+ template <typename F, F Func, typename A, typename... T>
383
+ struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
384
+ static std::tuple<Tensor, std::optional<int64_t>> apply(
385
+ const Tensor& self,
386
+ std::optional<int64_t> self_bdim,
387
+ T... extra_args) {
388
+ auto self_ = reshape_dim_into(*self_bdim, 0, self);
389
+ auto out = Func(self_, std::forward<T>(extra_args)...);
390
+ return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
391
+ }
392
+ };
393
+
394
+ // USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
395
+ // INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
396
+ // It is important that this macro is not passed a function pointer!!
397
+ #define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
398
+ ExistingBdimBatchRuleHelper<\
399
+ decltype(&fn),\
400
+ &fn,\
401
+ c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
402
+
403
+
404
+ #define EXISTING_BDIM(op) \
405
+ VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));
406
+
407
+ #define EXISTING_BDIM2(op, overload) \
408
+ VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
409
+
410
+ #define INVOKE(object,ptrToMember) ((object).*(ptrToMember))
411
+
412
+
413
+ template <typename F, F Method, typename... ExtraArgs>
414
+ Tensor& unary_inplace_batch_rule(Tensor& self, std::optional<int64_t> /*unused*/, ExtraArgs... extra_args) {
415
+ INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
416
+ return self;
417
+ }
418
+
419
+ inline int64_t get_bdim_size4(
420
+ const Tensor& a_value, std::optional<int64_t> a_bdim,
421
+ const Tensor& b_value, std::optional<int64_t> b_bdim,
422
+ const Tensor& c_value, std::optional<int64_t> c_bdim,
423
+ const Tensor& d_value, std::optional<int64_t> d_bdim) {
424
+ if (a_bdim)
425
+ return a_value.size(*a_bdim);
426
+ if (b_bdim)
427
+ return b_value.size(*b_bdim);
428
+ if (c_bdim)
429
+ return c_value.size(*c_bdim);
430
+ if (d_bdim)
431
+ return d_value.size(*d_bdim);
432
+ TORCH_INTERNAL_ASSERT(false);
433
+ }
434
+
435
+ inline int64_t get_bdim_size3(
436
+ const Tensor& a_value, std::optional<int64_t> a_bdim,
437
+ const Tensor& b_value, std::optional<int64_t> b_bdim,
438
+ const Tensor& c_value, std::optional<int64_t> c_bdim) {
439
+ if (a_bdim)
440
+ return a_value.size(*a_bdim);
441
+ if (b_bdim)
442
+ return b_value.size(*b_bdim);
443
+ if (c_bdim)
444
+ return c_value.size(*c_bdim);
445
+ TORCH_INTERNAL_ASSERT(false);
446
+ }
447
+
448
+ inline int64_t get_bdim_size2(
449
+ const Tensor& a_value, std::optional<int64_t> a_bdim,
450
+ const Tensor& b_value, std::optional<int64_t> b_bdim) {
451
+ if (a_bdim)
452
+ return a_value.size(*a_bdim);
453
+ if (b_bdim)
454
+ return b_value.size(*b_bdim);
455
+ TORCH_INTERNAL_ASSERT(false);
456
+ }
457
+
458
+ inline c10::SymInt get_bdim_size2_symint(
459
+ const Tensor& a_value, std::optional<int64_t> a_bdim,
460
+ const Tensor& b_value, std::optional<int64_t> b_bdim) {
461
+ if (a_bdim)
462
+ return a_value.sym_size(*a_bdim);
463
+ if (b_bdim)
464
+ return b_value.sym_size(*b_bdim);
465
+ TORCH_INTERNAL_ASSERT(false);
466
+ }
467
+
468
+ // [start, start + 1, ..., stop - 1]
469
+ inline VmapDimVector range(int64_t start, int64_t stop) {
470
+ TORCH_INTERNAL_ASSERT(stop >= start);
471
+ VmapDimVector dims;
472
+ dims.reserve(stop - start);
473
+ for (int64_t i = start; i < stop; i++) {
474
+ dims.emplace_back(i);
475
+ }
476
+ return dims;
477
+ }
478
+ std::tuple<Tensor, Tensor> _binary_pointwise_helper(
479
+ const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, const Tensor& other, std::optional<int64_t> other_batch_dim,
480
+ bool do_type_promotion=true);
481
+
482
+ } // namespace at::functorch
483
+
484
+ #else
485
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
486
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/core/op_registration/op_registration.h>
11
+ #include <torch/library.h>
12
+
13
+ namespace at::functorch {
14
+
15
+ // This file contains code for the vmap fallback (also known as the
16
+ // BatchedTensor fallback or the Batched fallback). This code runs
17
+ // when an operation doesn't have a batching rule implemented.
18
+
19
+ // If an operator doesn't have a batching rule implemented then we fallback
20
+ // to this implementation. The fallback doesn't work on out= variants or
21
+ // view operations; that is, it works for out-of-place operations and
22
+ // in-place non-view operations.
23
+ //
24
+ // For out-of-place operations, the fallback effectively takes all of the
25
+ // BatchedTensors in `stack`, slices them, and runs `op` on all of the
26
+ // corresponding slices to produce slices of the outputs. The output slices
27
+ // then get `torch.stack`ed to create the
28
+ // final returns.
29
+ //
30
+ // The performance of the fallback is not very good because it introduces an
31
+ // extra copy from stacking the sliced outputs. Because of this, we prefer to
32
+ // write batching rules for operators whenever possible.
33
+ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
34
+ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
35
+
36
+ void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
37
+
38
+ // The vmap fallback emits a warning by default, but it may be disabled if
39
+ // the user finds it to be too annoying.
40
+ TORCH_API bool isVmapFallbackWarningEnabled();
41
+ TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
42
+
43
+ // Used for testing. The vmap fallback is enabled by default. When it is disabled,
44
+ // it raises an error.
45
+ TORCH_API bool isVmapFallbackEnabled();
46
+ TORCH_API void setVmapFallbackEnabled(bool enabled);
47
+
48
+ template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
49
+ return buffer[0].to<A>();
50
+ }
51
+ template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
52
+ return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
53
+ }
54
+ template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
55
+ return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<B>());
56
+ }
57
+
58
+ // slow_fallback is a way to call the vmap fallback inside some boxed kernel.
59
+ // There is probably some better way to metaprogram this.
60
+ template <typename Ret>
61
+ Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
62
+ std::vector<IValue> stack(args.begin(), args.end());
63
+ batchedTensorForLoopFallback(op, &stack);
64
+ return vector_to_result<Ret>(stack);
65
+ }
66
+
67
+ template <typename A, typename B>
68
+ std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
69
+ std::vector<IValue> stack(args.begin(), args.end());
70
+ batchedTensorForLoopFallback(op, &stack);
71
+ return vector_to_result<A, B>(stack);
72
+ }
73
+
74
+ template <typename A, typename B, typename C>
75
+ std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
76
+ std::vector<IValue> stack(args.begin(), args.end());
77
+ batchedTensorForLoopFallback(op, &stack);
78
+ return vector_to_result<A, B, C>(stack);
79
+ }
80
+
81
+
82
+ } // namespace at::functorch
83
+
84
+ #else
85
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
86
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+
10
+ #include <bitset>
11
+
12
+ #include <ATen/ArrayRef.h>
13
+ #include <ATen/SmallVector.h>
14
+ #include <ATen/Tensor.h>
15
+
16
+ namespace at::functorch {
17
+
18
+ using Tensor = at::Tensor;
19
+
20
+ // We assume this in a few other places in the codebase,
21
+ // but there isn't a centralized definition.
22
+ constexpr int64_t kVmapMaxTensorDims = 64;
23
+
24
+ // The valid vmap levels range from [0, 64). This effectively means that we
25
+ // support a maximum of 64 nested vmaps.
26
+ constexpr int64_t kVmapNumLevels = 64;
27
+
28
+ // Store this number of elements of BatchDims on the stack. Most people will
29
+ // probably use <= 5 nested vmaps, but adjust this number as necessary.
30
+ constexpr int64_t kBatchDimsStackSize = 5;
31
+
32
+ // A BatchedTensorImpl holds an underlying Tensor and a single batch dim
33
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
34
+ // BatchedTensorImpl.
35
+ //
36
+ // The batch dimensions are treated as being "private"; they are not user-visible.
37
+ // For example, in the following Tensor,
38
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
39
+ // dimension 0 is batch dimension.
40
+ //
41
+ // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
42
+ // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
43
+ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
44
+ explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);
45
+
46
+ // Returns batch dimension of this tensor
47
+ int64_t bdim() const { return bdim_; }
48
+
49
+ // Returns batch dimension of this tensor
50
+ int64_t level() const { return level_; }
51
+
52
+ // BatchedTensorImpl wraps a Tensor
53
+ const Tensor& value() const { return value_; }
54
+
55
+ // Given a public dimension index, return the dimension index in the underlying
56
+ // value() tensor.
57
+ // For example, if we have
58
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
59
+ // bt.actualDim(0) -> 1
60
+ // bt.actualDim(1) -> 2
61
+ // bt.actualDim(2) -> 3
62
+ // bt.actualDim(3) -> Error
63
+ int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
64
+
65
+ IntArrayRef sizes_custom() const override;
66
+ SymIntArrayRef sym_sizes_custom() const override;
67
+ int64_t size_custom(int64_t d) const override;
68
+ c10::SymInt sym_size_custom(int64_t d) const override;
69
+ // We have to override this because we opted into CustomStrides
70
+ IntArrayRef strides_custom() const override;
71
+ SymIntArrayRef sym_strides_custom() const override;
72
+ // Override a bunch of methods inherited from TensorImpl to return error messages.
73
+ c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override;
74
+ void set_size(int64_t dim, int64_t new_size) override;
75
+ void set_stride(int64_t dim, int64_t new_stride) override;
76
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
77
+ const c10::VariableVersion& version_counter,
78
+ bool allow_tensor_metadata_change) const override;
79
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
80
+ c10::VariableVersion&& version_counter,
81
+ bool allow_tensor_metadata_change) const override;
82
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
83
+ #ifdef DEBUG
84
+ bool has_storage() const override;
85
+ #endif
86
+
87
+ void refreshTensorMetadata();
88
+
89
+ // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
90
+ // accomplishes this is a hack where it is able to modify the levels of
91
+ // BatchedTensor to match the level of the current vmap transform.
92
+ void _unsafe_set_level(int64_t level) {
93
+ level_ = level;
94
+ }
95
+
96
+ // Used in batching rule for in-place view operations that can change
97
+ // the index of the bdim (think squeeze_, unsqueeze_)
98
+ void unsafe_set_bdim(int64_t bdim) {
99
+ // NB: you MUST call refreshTensorMetadata after doing this.
100
+ bdim_ = bdim;
101
+ }
102
+ private:
103
+ // see NOTE: [BatchedTensorImpl levels invariant]
104
+ void checkInvariants() const;
105
+ const char* tensorimpl_type_name() const override;
106
+
107
+ Tensor value_;
108
+
109
+ int64_t level_;
110
+ int64_t bdim_;
111
+ };
112
+
113
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
114
+ // BatchedTensorImpl.
115
+ inline bool isBatchedTensor(const Tensor& tensor) {
116
+ return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
117
+ tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
118
+ }
119
+
120
+ // It is unsafe to call this on a Tensor that is not backed by a
121
+ // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
122
+ inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
123
+ return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
124
+ }
125
+
126
+ inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
127
+ if (!isBatchedTensor(tensor)) {
128
+ return nullptr;
129
+ }
130
+ return unsafeGetBatchedImpl(tensor);
131
+ }
132
+
133
+ // Returns a bitset. If bit i is set, then that means dim i is a batchdim.
134
+ inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
135
+ std::bitset<kVmapMaxTensorDims> is_bdim;
136
+ is_bdim.set(dim);
137
+ return is_bdim;
138
+ }
139
+
140
+ // Creates a bitset for the given level
141
+ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
142
+ std::bitset<kVmapNumLevels> result;
143
+ result.set(level);
144
+ return result;
145
+ }
146
+
147
+ // Use this to construct a BatchedTensor from a regular Tensor
148
+ TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);
149
+
150
+ // Adds a batch dim to `tensor`, returning a BatchedTensor
151
+ TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);
152
+
153
+ // Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
154
+ // any wrapper Tensor subclasses). This is because there are methods on Tensor
155
+ // that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
156
+ // TODO: should probably contain more (or all?) backend keys
157
+ constexpr DispatchKeySet kKeysToPropagateToWrapper({
158
+ DispatchKey::Negative,
159
+ DispatchKey::Conjugate,
160
+ DispatchKey::XLA,
161
+ DispatchKey::XPU,
162
+ DispatchKey::HPU,
163
+ DispatchKey::CUDA,
164
+ DispatchKey::CPU,
165
+ DispatchKey::PrivateUse1,
166
+ DispatchKey::SparseCPU,
167
+ DispatchKey::SparseCUDA,
168
+ DispatchKey::SparseCsrCPU,
169
+ DispatchKey::SparseCsrCUDA,
170
+ });
171
+
172
+ inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
173
+ auto key_set = tensor.unsafeGetTensorImpl()->key_set();
174
+ return key_set & kKeysToPropagateToWrapper;
175
+ }
176
+
177
+ } // namespace at::functorch
178
+
179
+ #else
180
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
181
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+ #include <ATen/Tensor.h>
10
+ #include <ATen/VmapGeneratedPlumbing.h>
11
+
12
+ // This file contains template metaprogramming things that are used for our
13
+ // batching rules.
14
+ //
15
+ // See NOTE: [vmap plumbing] for more details on why this is necessary.
16
+ // The plumbing has a bunch of metaprogramming hacks for determining the signature
17
+ // of a batching rule from the signature of the operator, many of which use the
18
+ // helper functions in this file.
19
+
20
+ namespace at::functorch {
21
+
22
+ // Metaprogramming things
23
+ template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
24
+ template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
25
+ template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
26
+ template <typename T> class debug_t;
27
+
28
+ // tail operation
29
+ template<class TypeList>
30
+ struct tail final {
31
+ static_assert(c10::guts::false_t<TypeList>::value,
32
+ "In typelist::tail<T>, the T argument must be typelist<...>.");
33
+ };
34
+ template<class Head, class... Tail>
35
+ struct tail<typelist<Head, Tail...>> final {
36
+ using type = typelist<Tail...>;
37
+ };
38
+ template<class TypeList> using tail_t = typename tail<TypeList>::type;
39
+
40
+ template <class First, class Second, class Next, class Tail>
41
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
42
+ using type = Next;
43
+ };
44
+ template <class Next, class Tail>
45
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, std::optional<int64_t>, Next, Tail> {
46
+ using type = Tail;
47
+ };
48
+ template <class Next, class Tail>
49
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, std::optional<int64_t>, Next, Tail> {
50
+ using type = Tail;
51
+ };
52
+ template <class Next, class Tail>
53
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, std::optional<int64_t>, Next, Tail> {
54
+ using type = Tail;
55
+ };
56
+ template <class Next, class Tail>
57
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>, std::optional<int64_t>, Next, Tail> {
58
+ using type = Tail;
59
+ };
60
+ template <class Next, class Tail>
61
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
62
+ using type = Tail;
63
+ };
64
+ template <class Next, class Tail>
65
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
66
+ using type = Tail;
67
+ };
68
+ template <class Next, class Tail>
69
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::vector<Tensor>, std::optional<int64_t>, Next, Tail> {
70
+ using type = Tail;
71
+ };
72
+ template <class TypeList> struct RemoveBatchDimAfterTensor {
73
+ using first = head_t<TypeList>;
74
+ using next = tail_t<TypeList>;
75
+ using second = head_t<next>;
76
+ using tail = tail_t<next>;
77
+
78
+ using type = concat_t<
79
+ typelist<first>,
80
+ typename RemoveBatchDimAfterTensor<
81
+ typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
82
+ >::type
83
+ >;
84
+ };
85
+ template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
86
+ using type = typelist<Type>;
87
+ };
88
+ template <> struct RemoveBatchDimAfterTensor<typelist<>> {
89
+ using type = typelist<>;
90
+ };
91
+ template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;
92
+
93
+ template <typename T> struct UnpackSingleItemTuple {
94
+ using type = T;
95
+ };
96
+ template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> {
97
+ using type = T;
98
+ };
99
+ template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;
100
+
101
+ template <typename Return, typename TupleArgs> struct BuildFunctionHelper;
102
+ template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> {
103
+ using type = Return(Args...);
104
+ };
105
+ template <typename Return, typename TL>
106
+ struct BuildFunction {
107
+ using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
108
+ };
109
+ template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;
110
+
111
+
112
+ template <typename batch_rule_t> struct ToOperatorType {
113
+ using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
114
+ using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;
115
+
116
+ using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
117
+ using operator_return_type =
118
+ unpack_single_item_tuple_t<
119
+ c10::guts::typelist::to_tuple_t<
120
+ remove_batch_dim_after_tensor_t<
121
+ c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;
122
+
123
+ using type = build_function_t<operator_return_type, operator_parameter_types>;
124
+ };
125
+ template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;
126
+
127
+ } // namespace at::functorch
128
+
129
+ #else
130
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
131
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+ #include <ATen/functorch/Macros.h>
10
+ #include <c10/core/DispatchKey.h>
11
+ #include <ATen/core/function_schema.h>
12
+ #include <optional>
13
+ #include <c10/core/impl/LocalDispatchKeySet.h>
14
+ #include <ATen/functorch/Interpreter.h>
15
+ #include <ATen/functorch/VmapInterpreter.h>
16
+ #include <ATen/functorch/ADInterpreters.h>
17
+ #include <ATen/functorch/FunctionalizeInterpreter.h>
18
+
19
+ // Forward declared
20
+ namespace c10 { struct AutogradMetaInterface; }
21
+
22
+ namespace at::functorch {
23
+
24
+ // This file contains the implementation of functorch's interpreter stack.
25
+ // See NOTE: [functorch interpreter stack] first before reading on.
26
+ //
27
+ // NB: the functorch interpreter stack is also referred to as:
28
+ // - the "dynamic layer stack" -- an older name for "interpreter" was
29
+ // "dynamic layer".
30
+ // - the "functorch mode stack". You can think of each functorch transform as a
31
+ // "mode" (in the same sense as torch_dispatch mode or torch_function mode),
32
+ // and functorch being an implementation of a "mode stack" where the modes
33
+ // may be arbitrary composed.
34
+
35
+ // DynamicLayer is basically the same thing as an Interpreter.
36
+ // It represents a functorch transform and it holds an Interpreter,
37
+ // which contains metadata related to the transform and instructions on
38
+ // how to perform the transform.
39
+ //
40
+ // TODO: we can excise DynamicLayer in favor of Interpreter,
41
+ // But I am going to leave it for now as a compatibility shim to avoid
42
+ // needing to refactor a lot of callsites...
43
+ struct TORCH_API DynamicLayer {
44
+ explicit DynamicLayer(
45
+ TransformType transform_type,
46
+ int64_t layerId,
47
+ std::optional<c10::SymInt> batchSize = std::nullopt,
48
+ std::optional<RandomnessType> randomness = std::nullopt,
49
+ std::optional<bool> prev_grad_mode = std::nullopt,
50
+ std::optional<bool> pre_fwd_grad_mode = std::nullopt,
51
+ std::optional<bool> functionalize_add_back_views = std::nullopt);
52
+
53
+ TransformType key() const;
54
+ int64_t layerId() const;
55
+
56
+ const Interpreter& interpreter() const { return interpreter_; }
57
+ Interpreter& interpreter() { return interpreter_; }
58
+
59
+ // Only valid for vmap
60
+ c10::SymInt batchSize() const;
61
+ RandomnessType randomness() const;
62
+
63
+ private:
64
+ Interpreter interpreter_;
65
+ };
66
+
67
+ TORCH_API int64_t initAndPushDynamicLayer(
68
+ TransformType transform_type,
69
+ std::optional<c10::SymInt> batch_size = std::nullopt,
70
+ std::optional<RandomnessType> randomness = std::nullopt,
71
+ std::optional<bool> prev_grad_mode = std::nullopt,
72
+ std::optional<bool> prev_fwd_grad_mode = std::nullopt,
73
+ std::optional<bool> functionalize_add_back_views = std::nullopt);
74
+ TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
75
+ TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
76
+ TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
77
+ TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
78
+ TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
79
+
80
+ // NOTE: [Life handles and lexically scoped transforms]
81
+ // functorch transforms are lexically scoped.
82
+ // Given a level, we store a "life handle" that is a boolean that tells us if the
83
+ // transform with that level is active or not.
84
+ //
85
+ // functorch's TensorWrapper (for grad transforms) stores a life handle.
86
+ // If a TensorWrapper escapes from the scope of the transform, then somehow
87
+ // it must know it escaped; it can tell by querying the life handle.
88
+ TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
89
+
90
+ // Returns if an operator is in-place. An operator is inplace if:
91
+ // 1. The first argument is a Tensor and it is being written to
92
+ // 2. The first argument is being returned
93
+ // 3. No other arguments are aliased
94
+ // Here is an example of an in-place operator:
95
+ // add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
96
+ TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
97
+
98
+ // Given the indices of unwrapped inputs and the schema, this returns the indices of any outputs that should remain unwrapped
99
+ TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
100
+
101
+ TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
102
+ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
103
+
104
+ // Pretty printers
105
+ TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
106
+ TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
107
+
108
+ // While a functorch transform is active, torch.autograd.function._SingleLevelFunction
109
+ // is disabled by default. The following two APIs are APIs for enabling
110
+ // it. These are not user-facing APIs. We can delete this in the future, but
111
+ // it is useful for debugging when something goes wrong with the
112
+ // autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
113
+ // because it leads to loud errors if something is incorrect.
114
+ TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
115
+ TORCH_API bool getSingleLevelAutogradFunctionAllowed();
116
+
117
+ // While a functorch grad transform is active, Tensor.requires_grad_() gets
118
+ // disabled. These two functions are the mechanism to controlling that.
119
+ TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
120
+ TORCH_API bool getInplaceRequiresGradAllowed();
121
+
122
+ TORCH_API DynamicLayer popDynamicLayer();
123
+ TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);
124
+
125
+ } // namespace at::functorch
126
+
127
+ #else
128
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
129
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/functorch/Interpreter.h>
4
+
5
+ namespace at::functorch {
6
+
7
+ // This is the interpreter that handles the functionalize() transform.
8
+ // See NOTE: [functorch interpreter stack] for more details.
9
+
10
+ struct FunctionalizeInterpreterPtr {
11
+ explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
12
+ TransformType key() const { return base_->key(); }
13
+ int64_t level() const { return base_->level(); }
14
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
15
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
16
+ bool functionalizeAddBackViews() const {
17
+ return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
18
+ }
19
+ private:
20
+ const Interpreter* base_;
21
+ };
22
+
23
+ } // namespace at::functorch
24
+
25
+ #else
26
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
27
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/functorch/Macros.h>
5
+ #include <ATen/core/dispatch/Dispatcher.h>
6
+ #include <c10/core/impl/LocalDispatchKeySet.h>
7
+ #include <c10/util/Exception.h>
8
+ #include <optional>
9
+ #include <bitset>
10
+ #include <utility>
11
+ #include <variant>
12
+
13
+ #include <nlohmann/json.hpp>
14
+
15
+ namespace at::functorch {
16
+
17
+ // NOTE: [functorch interpreter stack]
18
+ //
19
+ // functorch's dispatching system uses a stack of interpreters.
20
+ // Historically we've referred to this as the "DynamicLayerStack".
21
+ //
22
+ // An interpreter is something that reads in the code it is passed
23
+ // and then executes it. We have a different interpreter per-transform:
24
+ // the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
25
+ // and executing the batched version of it (the batching rule for aten::mv).
26
+ //
27
+ // Concretely, each interpreter is responsible for two things:
28
+ //
29
+ // 1) process(ophandle, stack)
30
+ // Given an operator handle and a stack of arguments, the interpreter is
31
+ // responsible for figuring out how to execute the operation under the semantics
32
+ // of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call
33
+ // the batching rule.
34
+ //
35
+ // The batching rules are stored as kernels on the FuncTorchBatched key, so the way
36
+ // VmapInterpreter calls the batching rule is roughly: (A) exclude all
37
+ // dispatch keys aside from the Batched key, (B) redispatch so we get to the
38
+ // Batched key.
39
+ //
40
+ // 2) sendToNextInterpreter(ophandle, stack)
41
+ // The VmapInterpreter, when it sees aten::mv, will process it into a call to
42
+ // aten::mm. It then needs to send the call to aten::mm to the next interpreter
43
+ // in the interpreter stack.
44
+ //
45
+ // The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
46
+ // and most Interpreters will implement it this way.
47
+
48
+ enum class RandomnessType {
49
+ Error, // always errors when calling a random function
50
+ Same, // randomness appears the same across batches
51
+ Different, // randomness appears different across batches
52
+ END
53
+ };
54
+
55
+ enum class TransformType {
56
+ Torch, // Unused
57
+ Vmap,
58
+ Grad, // reverse-mode AD, aka vjp
59
+ Jvp, // forward-mode AD
60
+ Functionalize,
61
+ };
62
+
63
+ std::ostream& operator<<(std::ostream& os, const TransformType& t);
64
+
65
+ // NOTE: [Interpreter "subclassing" design]
66
+ //
67
+ // How are various Interpreters for different transforms (vmap, grad, ...)
68
+ // implemented?
69
+ //
70
+ // Accessing interpreters is in the hot-path of functorch so we have a constraint
71
+ // that this code must be as fast as possible.
72
+ //
73
+ // As a result, we stay away from virtual methods and this causes our code
74
+ // to look a little funny.
75
+ //
76
+ // `Interpreter` is the struct for Interpreters. It holds ALL of the
77
+ // relevant information (what type of interpreter it is and the metadata).
78
+ // Metadata for each interpreter is represented as a Union (std::variant)
79
+ // of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
80
+ //
81
+ // Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
82
+ // if you want to access the metadata fields (like batchSize and randomness).
83
+ //
84
+ // Each type of interpreter (e.g. Vmap) has a convenience struct
85
+ // (e.g. VmapInterpreterPtr) associated with it.
86
+ //
87
+ // Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
88
+ // and then one can access methods on VmapInterpreterPtr like so:
89
+ // >>> VmapInterpreterPtr(&interpreter).batchSize()
90
+ //
91
+ // Finally, Interpreter::process switches on the type of the interpreter
92
+ // and calls one of {Transform}Interpreter::processImpl under the hood.
93
+ // Same for Interpreter::sendToNextInterpreter :)
94
+
95
+ struct VmapInterpreterMeta {
96
+ explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
97
+ batchSize_(std::move(batchSize)), randomness_(randomness) {}
98
+
99
+ c10::SymInt batchSize_;
100
+ RandomnessType randomness_;
101
+
102
+ VmapInterpreterMeta() = default;
103
+ VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
104
+ VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
105
+ VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
106
+ VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
107
+ ~VmapInterpreterMeta() = default;
108
+
109
+ template <typename T>
110
+ friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
111
+ TORCH_CHECK(
112
+ !json_t.batchSize_.is_heap_allocated(),
113
+ "Serialization for heap-allocated SymInt is not implemented yet"
114
+ );
115
+ json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
116
+ json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
117
+ }
118
+
119
+ template <typename T>
120
+ friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
121
+ json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
122
+ json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
123
+ }
124
+ };
125
+
126
+ struct GradInterpreterMeta {
127
+ explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
128
+ GradInterpreterMeta() = default;
129
+ GradInterpreterMeta(const GradInterpreterMeta&) = default;
130
+ GradInterpreterMeta(GradInterpreterMeta&&) = default;
131
+ GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
132
+ GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
133
+ ~GradInterpreterMeta() = default;
134
+
135
+ bool prevGradMode_;
136
+ template <typename T>
137
+ friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
138
+ json_j["prevGradMode"] = json_t.prevGradMode_;
139
+ }
140
+
141
+ template <typename T>
142
+ friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
143
+ json_t.prevGradMode_ = json_j["prevGradMode"];
144
+ }
145
+ };
146
+
147
+ struct JvpInterpreterMeta {
148
+ explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
149
+ JvpInterpreterMeta() = default;
150
+ JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
151
+ JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
152
+ JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
153
+ JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
154
+ ~JvpInterpreterMeta() = default;
155
+
156
+ bool prevFwdGradMode_;
157
+ template <typename T>
158
+ friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
159
+ json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
160
+ }
161
+
162
+ template <typename T>
163
+ friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
164
+ json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
165
+ }
166
+ };
167
+
168
+ struct FunctionalizeInterpreterMeta {
169
+ explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
170
+ functionalizeAddBackViews_(functionalizeAddBackViews) {}
171
+ FunctionalizeInterpreterMeta() = default;
172
+ FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
173
+ FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
174
+ FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
175
+ FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
176
+ ~FunctionalizeInterpreterMeta() = default;
177
+
178
+ bool functionalizeAddBackViews_;
179
+ template <typename T>
180
+ friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
181
+ json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
182
+ }
183
+
184
+ template <typename T>
185
+ friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
186
+ json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
187
+ }
188
+ };
189
+
190
+ typedef std::variant<
191
+ int64_t,
192
+ GradInterpreterMeta,
193
+ JvpInterpreterMeta,
194
+ VmapInterpreterMeta,
195
+ FunctionalizeInterpreterMeta
196
+ > InterpreterMeta;
197
+
198
+
199
+ struct Interpreter {
200
+ // factory functions
201
+ static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
202
+ return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
203
+ }
204
+ static Interpreter Grad(int64_t level, bool prevGradMode) {
205
+ return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
206
+ }
207
+ static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
208
+ return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
209
+ }
210
+ static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
211
+ return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
212
+ }
213
+
214
+ // methods
215
+ TransformType key() const { return type_; }
216
+ int64_t level() const { return level_; }
217
+ const InterpreterMeta& meta() const { return meta_; }
218
+
219
+ void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
220
+ void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
221
+
222
+ void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
223
+ TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
224
+ savedLocalDispatchKeySet_ = keyset;
225
+ }
226
+ void clearSavedLocalDispatchKeySet() {
227
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
228
+ savedLocalDispatchKeySet_ = std::nullopt;
229
+ }
230
+ c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
231
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
232
+ return *savedLocalDispatchKeySet_;
233
+ }
234
+
235
+ // An Interpreter is alive if we are currently inside the ongoing transform
236
+ // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
237
+ // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
238
+ bool is_alive() const {
239
+ return *is_alive_;
240
+ }
241
+ const std::shared_ptr<bool>& is_alive_ptr() const {
242
+ return is_alive_;
243
+ }
244
+ void set_is_alive(bool alive) {
245
+ *is_alive_ = alive;
246
+ }
247
+
248
+ // Please don't use this
249
+ explicit Interpreter() = default;
250
+
251
+ template <typename T>
252
+ friend void to_json(T& json_j, const Interpreter& json_t) {
253
+ json_j["type"] = static_cast<int64_t>(json_t.type_);
254
+ json_j["level"] = json_t.level_;
255
+ if (json_t.savedLocalDispatchKeySet_) {
256
+ json_j["savedLocalDispatchKeySet"] = {
257
+ {"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
258
+ {"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
259
+ };
260
+ } else {
261
+ json_j["savedLocalDispatchKeySet"] = nlohmann::json();
262
+ }
263
+ json_j["is_alive"] = *json_t.is_alive_;
264
+ std::visit([&](auto&& arg) {
265
+ using V = std::decay_t<decltype(arg)>;
266
+ if constexpr (std::is_same_v<V, int64_t>) {
267
+ json_j["meta"] = {{"Torch", arg}};
268
+ } else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
269
+ json_j["meta"] = {{"Grad", arg}};
270
+ } else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
271
+ json_j["meta"] = {{"Jvp", arg}};
272
+ } else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
273
+ json_j["meta"] = {{"Vmap", arg}};
274
+ } else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
275
+ json_j["meta"] = {{"Functionalize", arg}};
276
+ } else {
277
+ static_assert(false && sizeof(V), "unknown variant case");
278
+ }
279
+ }, json_t.meta_);
280
+ }
281
+
282
+ template <typename T>
283
+ friend void from_json(const T& json_j, Interpreter& json_t) {
284
+ json_t.type_ = static_cast<TransformType>(json_j["type"]);
285
+ json_t.level_ = json_j["level"];
286
+ auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
287
+ if (savedLocalDispatchKeySet.is_null()) {
288
+ json_t.savedLocalDispatchKeySet_ = std::nullopt;
289
+ } else {
290
+ c10::impl::PODLocalDispatchKeySet pod;
291
+ pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
292
+ pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
293
+ json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
294
+ }
295
+ json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
296
+ auto meta = json_j["meta"];
297
+ if (meta.contains("Torch")) {
298
+ json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
299
+ } else if (meta.contains("Grad")) {
300
+ json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
301
+ } else if (meta.contains("Jvp")) {
302
+ json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
303
+ } else if (meta.contains("Vmap")) {
304
+ json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
305
+ } else if (meta.contains("Functionalize")) {
306
+ json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
307
+ } else {
308
+ TORCH_CHECK(false, "unknown interpreter metadata type");
309
+ }
310
+ }
311
+
312
+ std::string serialize() const {
313
+ return nlohmann::json(*this).dump();
314
+ }
315
+
316
+ static Interpreter deserialize(const std::string& serialized) {
317
+ return nlohmann::json::parse(serialized).get<Interpreter>();
318
+ }
319
+
320
+ private:
321
+ explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
322
+ type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
323
+
324
+ // fields
325
+ TransformType type_{};
326
+ int64_t level_{};
327
+ std::optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
328
+ std::shared_ptr<bool> is_alive_;
329
+ InterpreterMeta meta_;
330
+ };
331
+
332
+ // Applies the following for-loop:
333
+ // for i in range(begin, end):
334
+ // args[i] = func(args[i])
335
+ void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
336
+ std::function<Tensor(const Tensor&)> func);
337
+
338
+ // Applies the following for-loop:
339
+ // for i in range(begin, end):
340
+ // if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset
341
+ // args[i] = func(args[i], i - begin, true)
342
+ // args[i] = func(args[i], i - begin)
343
+ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
344
+ const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);
345
+
346
+ std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);
347
+
348
+ DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
349
+
350
+ void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);
351
+
352
+ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
353
+
354
+ } // namespace at::functorch
355
+
356
+ #else
357
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
358
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)