koichi12 commited on
Commit
8509ad7
·
verified ·
1 Parent(s): 587c1a9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torchgen/dest/__init__.py +19 -0
  2. .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/native_functions.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py +707 -0
  9. .venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py +63 -0
  10. .venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py +1005 -0
  11. .venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py +551 -0
  12. .venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py +0 -0
  13. .venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/model.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torchgen/executorch/api/__init__.py +0 -0
  17. .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torchgen/executorch/api/custom_ops.py +149 -0
  22. .venv/lib/python3.11/site-packages/torchgen/executorch/api/et_cpp.py +370 -0
  23. .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__init__.py +4 -0
  24. .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/signatures.py +76 -0
  28. .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/types.py +83 -0
  29. .venv/lib/python3.11/site-packages/torchgen/executorch/api/unboxing.py +230 -0
  30. .venv/lib/python3.11/site-packages/torchgen/executorch/model.py +220 -0
  31. .venv/lib/python3.11/site-packages/torchgen/executorch/parse.py +153 -0
  32. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/native_functions.yaml +0 -0
  33. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/tags.yaml +74 -0
  34. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp +36 -0
  35. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp +73 -0
  36. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h +23 -0
  37. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h +29 -0
  38. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h +22 -0
  39. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp +13 -0
  40. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h +19 -0
  41. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Function.h +26 -0
  42. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h +33 -0
  43. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.cpp +103 -0
  44. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.h +143 -0
  45. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyIr.h +19 -0
  46. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h +11 -0
  47. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h +24 -0
  48. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h +17 -0
  49. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h +33 -0
  50. .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h +23 -0
.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchgen.dest.lazy_ir import (
2
+ generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
3
+ GenLazyIR as GenLazyIR,
4
+ GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
5
+ GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
6
+ )
7
+ from torchgen.dest.native_functions import (
8
+ compute_native_function_declaration as compute_native_function_declaration,
9
+ )
10
+ from torchgen.dest.register_dispatch_key import (
11
+ gen_registration_headers as gen_registration_headers,
12
+ gen_registration_helpers as gen_registration_helpers,
13
+ RegisterDispatchKey as RegisterDispatchKey,
14
+ )
15
+ from torchgen.dest.ufunc import (
16
+ compute_ufunc_cpu as compute_ufunc_cpu,
17
+ compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
18
+ compute_ufunc_cuda as compute_ufunc_cuda,
19
+ )
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (913 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-311.pyc ADDED
Binary file (40.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc ADDED
Binary file (3.52 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/native_functions.cpython-311.pyc ADDED
Binary file (3.58 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-311.pyc ADDED
Binary file (44.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc ADDED
Binary file (27.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ from abc import ABC
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ import torchgen.api.dispatcher as dispatcher
9
+ from torchgen.api.lazy import (
10
+ getValueT,
11
+ isValueType,
12
+ LazyArgument,
13
+ LazyIrProperties,
14
+ LazyIrSchema,
15
+ tensorListValueT,
16
+ )
17
+ from torchgen.api.translate import translate
18
+ from torchgen.api.types import (
19
+ BaseCType,
20
+ Binding,
21
+ deviceT,
22
+ DispatcherSignature,
23
+ kernel_signature,
24
+ NativeSignature,
25
+ OptionalCType,
26
+ VectorCType,
27
+ )
28
+ from torchgen.context import method_with_native_function
29
+ from torchgen.dest.lazy_ts_lowering import ts_lowering_body
30
+ from torchgen.model import (
31
+ Argument,
32
+ BackendIndex,
33
+ BackendMetadata,
34
+ BaseTy,
35
+ BaseType,
36
+ FunctionSchema,
37
+ ListType,
38
+ NativeFunction,
39
+ NativeFunctionsGroup,
40
+ )
41
+
42
+
43
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
    """
    Given a LazyArgument,
    generate a c++ string for materializing an rvalue of that arg for passing into
    a lazy Node constructor.
    """

    # TODO: Matching on CType seems wrong; should be matching on Type
    if isValueType(arg.lazy_type):
        # Value-typed args are lazy IR values; the expression depends on how the
        # caller materialized them (see GenLazyNativeFuncDefinition.lazy_tensor_decls,
        # which declares the node_{name} / lazy_{name} / lazy_{name}_tensorlist vars).
        if isinstance(arg.lazy_type, BaseCType):
            if arg.is_wrapped_scalar:
                # Wrapped scalars were already turned into IR values named node_{name}.
                return f"node_{arg.name}"
            elif arg.lazy_type.type is tensorListValueT:
                return f"lazy_{arg.name}_tensorlist"
            elif arg.is_symint_or_list:
                return f"GetSymIntValue({arg.name})"
            return f"lazy_{arg.name}->GetIrValue()"
        elif isinstance(arg.lazy_type, OptionalCType):
            if arg.is_symint_or_list:
                # TODO: I don't understand when you should put lazy_ in the name
                # or not
                return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
            elif arg.is_wrapped_scalar:
                return f"node_{arg.name}"
            # Optional tensor: forward the IR value only when present.
            return (
                f"lazy_{arg.name} ? "
                f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
                "::std::nullopt"
            )
        else:
            raise AssertionError(
                f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
            )
    else:
        # NB: this is here because right now we aren't treating SymInt[] as a
        # value type; when we do this needs to move above
        # NB: we cannot test arg.lazy_type as we've already specified it is an
        # int64_t and so we cannot distinguish between SymInt and int64_t
        if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
            BaseTy.SymInt
        ):
            if arg.symint:
                return f"GetSymIntArrayRefValue({arg.name})"
            else:
                return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
        elif isinstance(arg.lazy_type, VectorCType) and isinstance(
            arg.lazy_type.elem, BaseCType
        ):
            # Copy array-ref-like args into an owning std::vector.
            return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
        elif (
            isinstance(arg.lazy_type, OptionalCType)
            and isinstance(arg.lazy_type.elem, VectorCType)
            and isinstance(arg.lazy_type.elem.elem, BaseCType)
        ):
            return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
        else:
            # Plain scalar types pass through by name.
            return f"{arg.name}"
100
+
101
+
102
def node_ctor_inputs(schema: LazyIrSchema) -> str:
    """
    Produce a formatted string with the arguments as passed into the constructor of a node class.
    """
    rvalues = (node_ctor_arg_rvalue_string(a) for a in schema.filtered_args())
    return ", ".join(rvalues)
110
+
111
+
112
def gen_fallback_code(
    schema: LazyIrSchema,
    sig: DispatcherSignature | NativeSignature,
    overload_name: str,
) -> str:
    """
    Generate code that falls back to eager conditioned on a predicate
    """
    # Translate from the kernel signature's bindings to dispatcher bindings so
    # the fallback call can forward the original arguments.
    dispatcher_sig = DispatcherSignature.from_schema(schema.func)
    exprs = translate(sig.arguments(), dispatcher_sig.arguments())
    fallback_args = ",\n                ".join([a.expr for a in exprs])
    # Overloaded ops need the two-argument ATEN_OP2 macro form.
    if len(overload_name):
        aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
    else:
        aten_op_str = f"ATEN_OP({schema.aten_name})"
    return f"""
        if (force_eager_fallback({aten_symbol(schema)})) {{
            return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
                {fallback_args}
            );
        }}
"""
134
+
135
+
136
def aten_symbol(schema: LazyIrSchema) -> str:
    """Return the C++ expression that names the ATen symbol for this schema."""
    name = schema.aten_name
    # These ops have no interned at::aten:: constant, so the symbol must be
    # looked up from its qualified string at runtime instead.
    if name in {"sigmoid_backward"}:
        return f'c10::Symbol::fromQualString("aten::{name}")'
    # Already-qualified names pass through unchanged.
    if name.startswith("at::"):
        return name
    return f"at::aten::{name}"
147
+
148
+
149
# converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
    context: list[Binding] = []
    unwrapped_tensor_args: list[str] = []
    for arg in sig.arguments():
        if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
            # Each tensor-like argument gets a {name}_meta twin produced by
            # to_meta(); the binding is renamed so translate() uses the twin.
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(
                f"auto {unwrapped_name} = to_meta({arg.name});"
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            # Non-tensor arguments are forwarded unchanged.
            context.append(arg)
    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
166
+
167
+
168
@dataclass(frozen=True)
class GenLazyIR(ABC):
    """Generates the C++ lazy IR node class declaration for one operator.

    Backends subclass this to add lowering / Create / CanBeReused bodies
    (see GenTSLazyIR); the base class emits only the node skeleton.
    """

    backend_index: BackendIndex  # where kernel metadata (incl. symint support) is looked up
    backend_name: str
    node_base: str  # C++ base class of the generated node
    use_lazy_shape: bool  # whether nodes take a precomputed shapes vector

    @method_with_native_function
    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
        # For grouped functions, codegen is driven off the functional variant.
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        metadata = self.backend_index.get_kernel(
            f.functional if isinstance(f, NativeFunctionsGroup) else f
        )
        schema = LazyIrSchema(
            func, symint=metadata is not None and metadata.supports_symint()
        )
        return self.gen(schema)

    # there is no lowering functionality generated unless this IR base class is subclassed and
    # implemented as a backend-specific node
    def lowering_function(self, schema: LazyIrSchema) -> str:
        return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        return ""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        # Default: nodes are never reused; subclasses override with real checks.
        return f"""bool CanBeReused({node_ctor_args}) const {{
    return false;
    }}"""

    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
        """Build the C++ call to the node base-class constructor."""
        value_args = schema.filtered_args(values=True, scalars=False)
        # backends can customize the way the node base class constructor is called,
        # as long as all of its arguments can be generated from information available from the schema
        base_ctor_value_args_list = []
        for arg in value_args:
            if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
                base_ctor_value_args_list.append(f"{arg.name}")
            elif isinstance(arg.lazy_type, OptionalCType):
                base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
            else:
                raise AssertionError(
                    f"Unsupported type ({arg.lazy_type}) - add support if necessary"
                )
        base_ctor_value_args = ", ".join(base_ctor_value_args_list)

        scalar_args = schema.filtered_args(values=False, scalars=True)

        # Shape construction.
        # Conditionally build shape depending on specified shape property
        if schema.properties.ShapePrecompute:
            # Shapes were computed by the caller and passed into the node ctor.
            shape_ctor_arg = "std::move(shapes),"
        elif schema.properties.ShapeCompute:
            # Shape is computed eagerly inside the ctor from the raw args.
            shape_args = [a.name for a in value_args]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
        elif schema.properties.ShapeCache:
            # Shape is computed lazily (and cached) via a closure over operands.
            shape_args = [f"operand({i})" for i in range(len(value_args))]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
        else:
            shape_ctor_arg = ""

        # Scalars participate in the node hash; IR values are hashed by the base.
        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)

        return f"""{self.node_base}(
              {schema.node_name}::ClassOpKind(),
              OpList{{{base_ctor_value_args}}},
              {shape_ctor_arg}
              /* num_outputs */ {len(schema.returns)},
              torch::lazy::MHash({scalar_hashes}))"""

    def gen(self, schema: LazyIrSchema) -> list[str]:
        """Emit the full node class declaration for this schema."""
        opkind = schema.opkind or aten_symbol(schema)

        # for now, we just want one IR class decl and soon after also the method defs
        # and we use the functional version not out/inplace.
        all_args = schema.filtered_args()
        scalar_args = schema.filtered_args(values=False, scalars=True)

        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        # CanBeReused/Create take the args WITHOUT the trailing shapes vector.
        reuse_ctor_args = ", ".join(ctor_args)
        if self.use_lazy_shape and schema.properties.ShapePrecompute:
            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n        ".join(
            [
                # This code is just special casing the mapping from string_view -> strings
                f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.name}({a.name})"
                for a in scalar_args
            ]
        )
        if len(scalar_initializers):
            scalar_initializers = f",\n        {scalar_initializers}"
        # Member declarations; string_view members are stored as owning strings.
        scalar_decls = "\n  ".join(
            [
                f"std::string {a.name};"
                if a.lazy_type.cpp_type() == "c10::string_view"
                else f"::std::optional<std::string> {a.name};"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.lazy_type.cpp_type()} {a.name};"
                for a in scalar_args
            ]
        )
        optional_values = [
            arg.name
            for arg in schema.filtered_args(values=True, scalars=False)
            if isinstance(arg.lazy_type, OptionalCType)
        ]
        # One bitfield flag per optional IR value, recording presence.
        has_optional_decls = "\n  ".join(
            [f"bool has_{value}: 1;" for value in optional_values]
        )
        has_optional_defs = "\n    ".join(
            [f"has_{value} = !!{value};" for value in optional_values]
        )
        # ToString() body: print each scalar member (optionals print "null"
        # when absent; generators are opaque and print a placeholder).
        members_to_string = []
        for arg in scalar_args:
            if isinstance(arg.lazy_type, OptionalCType):
                value = f"{arg.name}.value()"
                if arg.is_generator:
                    value = '"torch.Generator()"'
                members_to_string.append(
                    f"""if ({arg.name}.has_value()) {{
      ss << ", {arg.name}=" << {value};
    }} else {{
      ss << ", {arg.name}=null";
    }}"""
                )
            else:
                members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
        members_to_string_str = "\n    ".join(members_to_string)

        return [
            f"""\
class {schema.node_name} : public {self.node_base} {{
 public:
  static torch::lazy::OpKind ClassOpKind() {{
    return torch::lazy::OpKind({opkind});
  }}

  {schema.node_name}({node_ctor_args})
      : {self.node_base_ctor_call(schema)}{scalar_initializers}
  {{
    {has_optional_defs}
  }}

  std::string ToString() const override {{
    std::stringstream ss;
    ss << {self.node_base}::ToString();
    {members_to_string_str}
    return ss.str();
  }}

  {self.create_function(schema, reuse_ctor_args)}

  {self.can_be_reused_function(schema, reuse_ctor_args)}

  {self.lowering_function(schema)}

  {scalar_decls}
  {has_optional_decls}

}};

""",
        ]
338
+
339
+
340
@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
    """GenLazyIR specialization for the TorchScript lazy backend.

    Fills in the Lower / Create / CanBeReused member functions that the base
    class leaves empty, gated on the schema's LazyIrProperties flags.
    """

    def lowering_function(self, schema: LazyIrSchema) -> str:
        signature = """
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override"""

        # Decl-only property wins over emitting a body; neither means no Lower().
        if schema.properties.LowerDeclOnly:
            return f"{signature};"
        elif schema.properties.Lower:
            return f"""{signature} {{
    {ts_lowering_body(schema)}
  }}
            """
        else:
            return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        signature = f"static NodePtr Create({node_ctor_args})"
        if schema.properties.CreateFnDeclOnly:
            return f"{signature};"
        elif not schema.properties.CreateFn:
            return ""
        return f"""{signature} {{
    return ReuseOrMakeNode<{schema.node_name}>(data);
  }}"""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        signature = f"bool CanBeReused({node_ctor_args}) const"
        if schema.properties.CanBeReusedDeclOnly:
            return f"{signature};"
        elif not schema.properties.CanBeReused:
            return ""
        # A cached node can be reused only if every operand and scalar member
        # compares equal to the incoming arguments.
        value_comparison = []
        for arg in itertools.chain(schema.positional_values, schema.keyword_values):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
                )
            else:
                value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
            if isinstance(arg.lazy_type, OptionalCType):
                # Optionals match when both are empty or both hold equal values.
                value_comparison.append(
                    f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
                )
            else:
                value_comparison.append(f"this->{arg.name} == {arg.name}")
        value_comparison_str = " &&\n        ".join(value_comparison)

        return f"""{signature} {{
    size_t i = 0;
    return ({value_comparison_str});
  }}"""
395
+
396
+
397
@dataclass(frozen=True)
class GenLazyNativeFuncDefinition:
    """Generates the C++ definition of a lazy backend's kernel for one op.

    The emitted kernel: (optionally) falls back to eager, bumps a metrics
    counter, resolves a common device, wraps tensor args as lazy tensors,
    builds/reuses the IR node, and returns at::Tensor result(s).
    All string fields parameterize backend-specific helper names so the same
    generator serves multiple backends (TS, XLA, ...).
    """

    class_method_name: str
    backend_index: BackendIndex
    tensor_class: str
    gen_forced_fallback_code: bool
    backend_namespace: str
    get_tensorlist: str
    get_tensor_or_wrap_number: str
    try_get_tensor: str
    metrics_counter: str
    create_tensor: str
    create_from_first_tensor: bool
    create_aten_from_ltc_tensor: str
    tuple_aten_from_ltc_tensors: str
    lazy_tensor_ptr: str
    get_device_fn: str

    def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        value_args = schema.filtered_args(values=True, scalars=False)
        # Generates lazy_{name} variables for LazyTensors wrapping input tensors
        lazy_tensor_decls: list[str] = []
        for arg in value_args:
            if arg.is_wrapped_scalar:
                # Wrapped scalars become IR values directly (node_{name}).
                if isinstance(arg.lazy_type, OptionalCType):
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = {arg.name} ?
                std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
                    GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
                ::std::nullopt;"""
                    )
                else:
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
                            GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
                    )
            elif arg.is_symint_or_list:
                continue  # values are extracted in isValueType
            elif isinstance(arg.lazy_type, BaseCType):
                if arg.lazy_type.type is tensorListValueT:
                    lazy_tensor_decls.append(
                        f"auto lazy_{arg.name}_tensorlist = "
                        f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
                    )
                else:
                    lazy_tensor_decls.append(
                        f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                        f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
                    )
            elif isinstance(arg.lazy_type, OptionalCType):
                assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
                # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
                # until we encounter a real world example.
                lazy_tensor_decls.append(
                    f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                    f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
                )
            else:
                raise AssertionError(
                    f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
                )
        return ("\n        ").join(lazy_tensor_decls)

    def force_eager_fallback(
        self,
        func: NativeFunction,
        schema: LazyIrSchema,
        metadata: BackendMetadata,
        sig: DispatcherSignature | NativeSignature,
    ) -> str:
        # Emits the force_eager_fallback(...) guard only when the backend
        # asked for it; otherwise the kernel body has no fallback path.
        if self.gen_forced_fallback_code:
            return gen_fallback_code(
                schema, sig, overload_name=func.func.name.overload_name
            )
        return ""

    def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        # Backend-provided counter statement, e.g. op-level metrics.
        return f"{self.metrics_counter};"

    def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code that resolves a common device from tensor/device args."""
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        optional_device = OptionalCType(BaseCType(deviceT))
        optional_devices = [
            a.name for a in scalar_args if a.lazy_type == optional_device
        ]
        assert (
            len(value_types_names) > 0 or len(optional_devices) > 0
        ), "Expected at least one Value or Device type"
        get_device_str = (
            f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
        )
        return f"""auto common_device = {get_device_str};
        TORCH_INTERNAL_ASSERT(common_device);
        """

    def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code computing the output shapes vector for the IR node."""
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        all_args = schema.filtered_args()
        returns_length = len(schema.returns)
        # call the meta kernel if it exists, to compute output shape/dtype for our IR
        # Note [Generated LTC Shape Functions]
        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
        # we generate a shape function declaration that needs to be manually implemented.
        # How do we detect which ops are eligible to use meta tensors?
        # In general we should be able to use meta tensors not just on structured operators,
        # but also on composite operators that are implemented in terms of structured kernels.
        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
        # This is the case for all view and view_copy operators however, so we're going to
        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
        is_view_copy_op = "view_copy" in func.tags
        is_structured = func.structured or func.structured_delegate is not None
        if is_structured or is_view_copy_op:
            meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
            if returns_length > 1:

                def this_shape(i: int) -> str:
                    return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

                shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

            # Convert tensor args to the meta device and call it.
            # (We can't pass in the input tensors directly, because they are "functional wrappers".
            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
            dispatcher_sig = DispatcherSignature.from_schema(func.func)
            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
            meta_call_args = [
                e.expr
                for e in translate(
                    meta_call_ctx, dispatcher_sig.arguments(), method=False
                )
            ]
            if is_view_copy_op:
                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
                assert func.has_composite_explicit_autograd_non_functional_kernel
                dispatch_ns = "compositeexplicitautogradnonfunctional"
            else:
                dispatch_ns = "meta"
            aten_name = schema.aten_name
            # TODO: this is trolling
            if func.func.has_symint() and metadata.supports_symint():
                aten_name += "_symint"
            shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
        else:
            # No meta kernel available: call the hand-written shape function
            # declared via ComputeShapeSignature.
            shape_sig = ComputeShapeSignature(
                metadata.kernel, func, symint=metadata.supports_symint()
            )
            shape_str = f"""
            auto shapes = {shape_sig.shape_call};"""

        shape_str += f"""
            TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

        # Calculating which dimensions are symbolic
        func_schema_str = "aten::" + str(func.func)
        shape_str += f"""
            if(torch::lazy::symbolicShapeEnabled()){{
                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
                const char* schema_str = "{func_schema_str}";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }}
        """
        return shape_str

    def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        # Try to reuse a cached node; only on miss do we pay for shape
        # inference, node construction, and cache insertion.
        node_ctor_input_str = node_ctor_inputs(schema)
        return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
        if (!node) {{
            {self.shape_inference(func, schema)}
            node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
            CacheNode(node);
        }}
        """

    def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
        # xla uses an instance method for tensor creation, for the time being
        if self.create_from_first_tensor:
            # TODO(whc) remove this if XLA switches to using static method for creation
            assert (
                first_tensor_name is not None
            ), "Requires first tensor to create lazy tensor"
            return f"{first_tensor_name}.{self.create_tensor}"
        return f"{self.backend_namespace}::{self.create_tensor}"

    def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code converting the IR node back into at::Tensor result(s)."""
        returns_length = len(schema.returns)
        value_args = schema.filtered_args(values=True, scalars=False)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
        bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
                {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""

        if returns_length > 1:
            assert (
                len(value_types_names) > 0
            ), "Code below assumes there is at least one tensor arg"
            # Multi-output ops: one lazy tensor per output, packed into a tuple.
            bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
        for (int i = 0; i < {returns_length}; i++) {{
            lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
        }}
        auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""

        if schema.name.name.inplace or func.func.is_out_fn():
            assert returns_length == 1, (
                "We assumed there was no such case where an op is an in-place variant "
                f"and has tuple outputs, but got tuple of len {returns_length}."
            )
            # In-place/out variants update the first tensor and return it by ref.
            bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
        auto& result = {first_tensor_name};"""

        bridge_str += """
        return result;"""
        return bridge_str

    @method_with_native_function
    def __call__(self, func: NativeFunction) -> list[str]:
        # Stitch all the pieces above into the full kernel definition.
        sig = kernel_signature(func, self.backend_index)
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
        return [
            f"""\
    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
        {self.force_eager_fallback(func, schema, metadata, sig)}
        {self.metrics(func, schema)}
        {self.get_device(func, schema)}
        {self.lazy_tensor_decls(func, schema)}
        {self.build_ir_node(func, schema)}
        {self.return_aten_tensor(func, schema)}
    }}\n
    """
        ]
637
+
638
+
639
class ComputeShapeSignature:
    """
    Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
    """

    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
        schema = LazyIrSchema(f.func, symint=symint)
        self.__schema = schema
        # Declaration uses full dispatcher-style parameter declarations;
        # the call site only needs the bare argument names.
        self.__dispatch_args = ", ".join(
            a.decl() for a in dispatcher.arguments(f.func, symint=symint)
        )
        self.__call_args = ", ".join(
            arg.name for arg in schema.filtered_args(generator=True)
        )
        self.__kernel_name = kernel_name

    @property
    def shape_decl(self) -> str:
        """C++ declaration of the compute_shape_<kernel> function."""
        return (
            "TORCH_API std::vector<torch::lazy::Shape> "
            f"compute_shape_{self.__kernel_name}({self.__dispatch_args})"
        )

    @property
    def shape_call(self) -> str:
        """C++ expression invoking the compute_shape_<kernel> function."""
        return f"torch::lazy::compute_shape_{self.__kernel_name}({self.__call_args})"
667
+
668
+
669
@dataclass(frozen=True)
class GenLazyShapeInferenceDefinition:
    """Emits a shape-function declaration for ops that need a hand-written one."""

    backend_index: BackendIndex
    tensor_class: str

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> list[str]:
        metadata = self.backend_index.get_kernel(f)
        assert metadata is not None

        # See Note [Generated LTC Shape Functions]
        # Structured ops (and view_copy ops) get shapes from meta kernels,
        # so they need no hand-written shape function declaration.
        structured_like = f.structured or f.structured_delegate is not None
        if structured_like or "view_copy" in f.tags:
            return []

        shape_sig = ComputeShapeSignature(
            metadata.kernel, f, symint=metadata.supports_symint()
        )
        return [f"{shape_sig.shape_decl};"]
689
+
690
+
691
def generate_non_native_lazy_ir_nodes(
    non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
    """Generate the non-native lazy IR node classes"""

    def render_one(op: dict[str, Any]) -> str:
        # Set default properties for Non-Native IRs
        props = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
        for prop_name in op.get("properties", []):
            setattr(props, prop_name, True)

        # non-native is assumed to want symint bindings if you wrote symint
        schema = LazyIrSchema(FunctionSchema.parse(op["func"]), props, symint=True)
        schema.opkind = op.get("opkind")
        return gen_lazy_ir.gen(schema)[0]

    return [render_one(op) for op in non_native]
.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torchgen.api.meta as meta
4
+ import torchgen.api.structured as structured
5
+ from torchgen.api.types import kernel_signature
6
+ from torchgen.context import with_native_function_and_index
7
+ from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
8
+ from torchgen.utils import mapMaybe
9
+
10
+
11
+ @with_native_function_and_index
12
+ def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
13
+ sig = kernel_signature(f, backend_index)
14
+ metadata = backend_index.get_kernel(f)
15
+ if metadata is None:
16
+ return None
17
+ if "legacy::" in metadata.kernel:
18
+ return None
19
+ else:
20
+ prefix = "static" if backend_index.external else "TORCH_API"
21
+ return f"{prefix} {sig.decl(name=metadata.kernel)};"
22
+
23
+
24
+ @with_native_function_and_index
25
+ def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
26
+ meta_name = meta.name(g)
27
+ out_args = structured.impl_arguments(g)
28
+ metadata = backend_index.get_kernel(g)
29
+ if metadata is None:
30
+ return []
31
+ prefix = "" if backend_index.external else "TORCH_API "
32
+ return [
33
+ f"""\
34
+ struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
35
+ void impl({', '.join(a.decl() for a in out_args)});
36
+ }};
37
+ """
38
+ ]
39
+
40
+
41
+ # Generates NativeFunctions.h, a list of forward declarations of all
42
+ # actual kernel definitions we keep in aten/src/ATen/native/
43
+ @with_native_function_and_index
44
+ def compute_native_function_declaration(
45
+ g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
46
+ ) -> list[str]:
47
+ metadata = backend_index.get_kernel(g)
48
+ if isinstance(g, NativeFunctionsGroup):
49
+ if metadata is not None and metadata.structured:
50
+ if backend_index.external:
51
+ # Structured hasn't been tested with external backends yet.
52
+ raise AssertionError(
53
+ "Structured external backend functions are not implemented yet."
54
+ )
55
+ else:
56
+ return gen_structured(g, backend_index)
57
+ else:
58
+ return list(
59
+ mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
60
+ )
61
+ else:
62
+ x = gen_unstructured(g, backend_index)
63
+ return [] if x is None else [x]
.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py ADDED
@@ -0,0 +1,1005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ import textwrap
5
+ from dataclasses import dataclass
6
+ from typing import Literal, TYPE_CHECKING
7
+
8
+ import torchgen.api.cpp as cpp
9
+ import torchgen.api.meta as meta
10
+ import torchgen.api.structured as structured
11
+ from torchgen.api.translate import translate
12
+ from torchgen.api.types import (
13
+ BaseCType,
14
+ Binding,
15
+ ConstRefCType,
16
+ CppSignature,
17
+ CppSignatureGroup,
18
+ DispatcherSignature,
19
+ Expr,
20
+ kernel_signature,
21
+ MutRefCType,
22
+ NamedCType,
23
+ NativeSignature,
24
+ tensorT,
25
+ )
26
+ from torchgen.context import method_with_native_function, native_function_manager
27
+ from torchgen.model import (
28
+ Argument,
29
+ BackendIndex,
30
+ DeviceCheckType,
31
+ DispatchKey,
32
+ gets_generated_out_inplace_wrapper,
33
+ is_cuda_dispatch_key,
34
+ NativeFunction,
35
+ NativeFunctionsGroup,
36
+ SchemaKind,
37
+ TensorOptionsArguments,
38
+ )
39
+ from torchgen.utils import assert_never, mapMaybe, Target
40
+
41
+
42
+ if TYPE_CHECKING:
43
+ from torchgen.selective_build.selector import SelectiveBuilder
44
+
45
+
46
+ def gen_registration_headers(
47
+ backend_index: BackendIndex,
48
+ per_operator_headers: bool,
49
+ rocm: bool,
50
+ ) -> list[str]:
51
+ if per_operator_headers:
52
+ headers = ["#include <ATen/ops/as_strided_native.h>"]
53
+ else:
54
+ headers = ["#include <ATen/NativeFunctions.h>"]
55
+
56
+ if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
57
+ headers.append("#include <ATen/EmptyTensor.h>")
58
+ elif backend_index.dispatch_key == DispatchKey.CUDA:
59
+ if rocm:
60
+ headers.append("#include <ATen/hip/EmptyTensor.h>")
61
+ else:
62
+ headers.append("#include <ATen/cuda/EmptyTensor.h>")
63
+ elif backend_index.dispatch_key == DispatchKey.MPS:
64
+ headers.append("#include <ATen/mps/EmptyTensor.h>")
65
+ elif backend_index.dispatch_key == DispatchKey.XPU:
66
+ # XPU specific, this header resides in third_party/torch-xpu-ops
67
+ headers.append("#include <ATen/xpu/EmptyTensor.h>")
68
+ elif per_operator_headers:
69
+ headers += [
70
+ "#include <ATen/ops/empty.h>",
71
+ "#include <ATen/ops/empty_strided.h>",
72
+ "#include <ATen/ops/_copy_from_and_resize.h>",
73
+ "#include <ATen/ops/_copy_from.h>",
74
+ ]
75
+ else:
76
+ headers.append("#include <ATen/Functions.h>")
77
+
78
+ headers.append("#include <c10/macros/Macros.h>")
79
+ return headers
80
+
81
+
82
+ def gen_empty_impl_names(
83
+ backend_index: BackendIndex,
84
+ ) -> tuple[str | None, str | None]:
85
+ empty_impl = None
86
+ empty_strided_impl = None
87
+
88
+ if backend_index.dispatch_key in (
89
+ DispatchKey.Meta,
90
+ DispatchKey.CPU,
91
+ DispatchKey.CUDA,
92
+ DispatchKey.MPS,
93
+ DispatchKey.XPU,
94
+ ):
95
+ dispatch = str(backend_index.dispatch_key).lower()
96
+ empty_impl = f"at::detail::empty_{dispatch}"
97
+ empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
98
+ elif backend_index.dispatch_key in (
99
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
100
+ DispatchKey.QuantizedCPU,
101
+ DispatchKey.QuantizedCUDA,
102
+ DispatchKey.XPU,
103
+ ):
104
+ empty_impl = "at::empty"
105
+ empty_strided_impl = "at::empty_strided"
106
+
107
+ return empty_impl, empty_strided_impl
108
+
109
+
110
+ def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
111
+ if backend_index.dispatch_key == DispatchKey.Meta:
112
+ empty_options = "options.device(at::kMeta)"
113
+ else:
114
+ empty_options = "options"
115
+
116
+ empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
117
+ if empty_impl is None:
118
+ return []
119
+
120
+ return [
121
+ f"""
122
+ Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
123
+ if (strides.empty()) {{
124
+ return {empty_impl}(sizes, {empty_options});
125
+ }} else {{
126
+ return {empty_strided_impl}(sizes, strides, {empty_options});
127
+ }}
128
+ }}
129
+ """
130
+ ]
131
+
132
+
133
+ def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
134
+ _, empty_strided_impl = gen_empty_impl_names(backend_index)
135
+ return (
136
+ []
137
+ if empty_strided_impl is None
138
+ else [
139
+ f"""
140
+ std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
141
+ if (out.strides() != strides) {{
142
+ return {empty_strided_impl}(sizes, strides, options);
143
+ }}
144
+ return std::nullopt;
145
+ }}
146
+ """
147
+ ]
148
+ )
149
+
150
+
151
+ def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
152
+ if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
153
+ # The function isn't used by this key (since only functional ops have a kernel for this key),
154
+ # so we need to not include it to avoid a defined-but-not-used error.
155
+ return []
156
+ return [
157
+ """
158
+ void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
159
+ TORCH_CHECK(options.dtype() == out.dtype(),
160
+ "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
161
+ TORCH_CHECK(options.device() == out.device(),
162
+ "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
163
+ const bool resized = at::native::resize_output(out, sizes);
164
+ // Only restride if a resize occurred; otherwise we ignore the (advisory)
165
+ // strides from the meta function and directly use the output tensor's
166
+ // preexisting strides
167
+ if (resized) {
168
+ if (!strides.empty()) {
169
+ TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
170
+ // TODO: avoid the redispatch here
171
+ out.as_strided_(sizes, strides);
172
+ } else if (options.memory_format_opt().has_value()) {
173
+ out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
174
+ }
175
+ }
176
+ }
177
+ """
178
+ ]
179
+
180
+
181
+ def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
182
+ return [
183
+ """
184
+ void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
185
+ // These checks are needed on those operators that:
186
+ // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
187
+ // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
188
+ // For other operators (e.g. 'add'), 'TensorIterator' already checks
189
+ // these things separately.
190
+ TORCH_CHECK(options.dtype() == self.dtype(),
191
+ "Bad in-place call: ",
192
+ "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
193
+ TORCH_CHECK(options.device() == self.device(),
194
+ "Bad in-place call: ",
195
+ "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
196
+ TORCH_CHECK(sizes == self.sizes(),
197
+ "Bad in-place call: ",
198
+ "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
199
+ }
200
+ """
201
+ ]
202
+
203
+
204
+ def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
205
+ return [
206
+ 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
207
+ *gen_create_out_helper(backend_index),
208
+ *gen_resize_out_helper(backend_index),
209
+ *gen_check_inplace_helper(backend_index),
210
+ *gen_maybe_create_proxy_helper(backend_index),
211
+ "C10_DIAGNOSTIC_POP()",
212
+ ]
213
+
214
+
215
+ # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
216
+ #
217
+ # - The primary function of this file is to register all of the
218
+ # implementations for the given dispatch key to the dispatcher,
219
+ # so they are available for use in PyTorch. If dispatch is
220
+ # None, we generate schema (def) registrations and catchall
221
+ # registrations.
222
+ # - The secondary function of this file is to generate a wrapper
223
+ # around functions. In CPUType these wrappers do nothing
224
+ # (and should be removed), but in other cases they handle
225
+ # DeviceGuard. A small extra benefit of wrappers is they
226
+ # are not overloaded, so they can be used in the registration
227
+ # API without having to disambiguate which overload you want
228
+ # (as would be the case if you directly registered native::
229
+ # functions).
230
+ # - The tertiary function of this file is to generate *static*
231
+ # cpp API bindings which can be used to bypass dispatcher
232
+ # directly to kernels, but with user-friendly cpp-style API
233
+ @dataclass(frozen=True)
234
+ class RegisterDispatchKey:
235
+ backend_index: BackendIndex
236
+
237
+ target: Literal[
238
+ Target.ANONYMOUS_DEFINITION,
239
+ Target.NAMESPACED_DEFINITION,
240
+ Target.NAMESPACED_DECLARATION,
241
+ Target.REGISTRATION,
242
+ ]
243
+
244
+ # Selector object to determine which operators to generate
245
+ # registration code for.
246
+ selector: SelectiveBuilder
247
+
248
+ # Whether or not we are actually code-genning for ROCm
249
+ rocm: bool
250
+
251
+ # Whether or not to generate symint registrations or not. External users
252
+ # of codegen who don't care about symints can set this to false to get
253
+ # non-SymInt codegen
254
+ symint: bool
255
+
256
+ # The class that all unstructured native functions live under. This is used to improve
257
+ # compiler error messages when a kernel writer adds a native function with the wrong signature.
258
+ # This is only used in unstructured kernels, since structured kernels already live in a class.
259
+ # Finally, this field is currently Optional because it is only used by external backends.
260
+ # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
261
+ # all of the existing kernel signatures scattered across aten/src/ATen/native.
262
+ class_method_name: str | None
263
+
264
+ # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
265
+ # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
266
+ skip_dispatcher_op_registration: bool
267
+
268
+ @staticmethod
269
+ def gen_device_check(
270
+ type: DeviceCheckType, args: list[Argument], method_name: str
271
+ ) -> str:
272
+ if type == DeviceCheckType.NoCheck:
273
+ return " // No device check\n"
274
+
275
+ device_check = "std::optional<Device> common_device = std::nullopt;\n"
276
+ device_check += "(void)common_device; // Suppress unused variable warning\n"
277
+ for arg in args:
278
+ # Only tensor like arguments are eligible
279
+ if arg.type.is_tensor_like():
280
+ device_check += f"""
281
+ c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
282
+ return device_check
283
+
284
+ @method_with_native_function
285
+ def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
286
+ if isinstance(f, NativeFunctionsGroup):
287
+ g: NativeFunctionsGroup = f
288
+ # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
289
+ # gen_structured() has special logic to handle auto-generated kernels.
290
+ if g.structured:
291
+ return self.gen_structured(g)
292
+ else:
293
+ return list(
294
+ mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
295
+ )
296
+ elif isinstance(f, NativeFunction):
297
+ r = self.gen_unstructured(f)
298
+ return [] if r is None else [r]
299
+ else:
300
+ assert_never(f)
301
+
302
+ def wrapper_kernel_sig(
303
+ self, f: NativeFunction
304
+ ) -> NativeSignature | DispatcherSignature:
305
+ # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
306
+ return DispatcherSignature.from_schema(
307
+ f.func,
308
+ prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
309
+ symint=self.symint,
310
+ )
311
+
312
+ def gen_out_inplace_wrapper(
313
+ self, f: NativeFunction, g: NativeFunctionsGroup | None
314
+ ) -> str | None:
315
+ if g is None:
316
+ return None
317
+ k = f.func.kind()
318
+ if k is SchemaKind.inplace:
319
+ copy_op = "at::_copy_from"
320
+ elif k is SchemaKind.out:
321
+ copy_op = "at::_copy_from_and_resize"
322
+ else:
323
+ raise AssertionError("gen_out_inplace_wrapper called on a functional op")
324
+
325
+ sig = self.wrapper_kernel_sig(f)
326
+ name = sig.name()
327
+
328
+ func_res = f"{name}_tmp"
329
+ return_names = cpp.return_names(f)
330
+ if len(return_names) > 1:
331
+ updates = "\n ".join(
332
+ f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
333
+ for i, ret_name in enumerate(return_names)
334
+ )
335
+ returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
336
+ elif len(return_names) == 1:
337
+ ret_name = return_names[0]
338
+ updates = f"{copy_op}({func_res}, {ret_name});"
339
+ returns = ret_name
340
+ else:
341
+ assert len(f.func.arguments.out) == 1
342
+ returns = ""
343
+ out_arg = f.func.arguments.out[0]
344
+ if out_arg.type.is_list_like():
345
+ updates = f"""\
346
+ for (int64_t i = 0; i < {func_res}.size(); ++i) {{
347
+ {copy_op}({func_res}[i], {out_arg.name}[i]);
348
+ }}"""
349
+ else:
350
+ updates = f"{copy_op}({func_res}, {out_arg.name});"
351
+
352
+ functional_sig = self.wrapper_kernel_sig(g.functional)
353
+ wrapper_name = sig.name()
354
+
355
+ return f"""\
356
+ {sig.defn(name=wrapper_name)} {{
357
+ auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
358
+ {updates}
359
+ return {returns};
360
+ }}
361
+ """
362
+
363
+ def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
364
+ metadata = self.backend_index.get_kernel(g)
365
+ if self.backend_index.dispatch_key == DispatchKey.Meta:
366
+ assert not self.backend_index.has_kernel(g.out), (
367
+ "Do not explicitly specify Meta dispatch key on structured "
368
+ "functions, they will be automatically generated for you"
369
+ )
370
+ elif (
371
+ self.backend_index.dispatch_key
372
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
373
+ ):
374
+ assert not self.backend_index.has_kernel(g.out), (
375
+ "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
376
+ "functions, they will be automatically generated for you"
377
+ )
378
+ elif metadata is None or not metadata.structured:
379
+ return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
380
+ structured_gen = StructuredRegisterDispatchKey(
381
+ self.backend_index,
382
+ self.target,
383
+ self.selector,
384
+ self.rocm,
385
+ self.symint,
386
+ self.class_method_name,
387
+ self.skip_dispatcher_op_registration,
388
+ g,
389
+ )
390
+ return list(mapMaybe(structured_gen.gen_one, g.functions()))
391
+
392
+ def gen_unstructured(
393
+ self, f: NativeFunction, g: NativeFunctionsGroup | None = None
394
+ ) -> str | None:
395
+ with native_function_manager(f):
396
+ inplace_meta = False
397
+ gets_out_inplace_wrapper = False
398
+ if not self.backend_index.has_kernel(f):
399
+ if (
400
+ self.backend_index.dispatch_key == DispatchKey.Meta
401
+ and f.func.kind() is SchemaKind.inplace
402
+ and
403
+ # Defer to composites for meta implementation
404
+ not f.has_composite_kernel
405
+ and
406
+ # Inplace list operations are not supported
407
+ len(f.func.returns) == 1
408
+ ):
409
+ inplace_meta = True
410
+ elif (
411
+ not self.backend_index.use_out_as_primary
412
+ and g is not None
413
+ and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
414
+ ):
415
+ # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
416
+ gets_out_inplace_wrapper = True
417
+ else:
418
+ return None
419
+ if f.manual_kernel_registration:
420
+ return None
421
+
422
+ if (
423
+ self.target is Target.REGISTRATION
424
+ and not self.selector.is_native_function_selected(f)
425
+ ):
426
+ return None
427
+
428
+ sig = self.wrapper_kernel_sig(f)
429
+
430
+ name = sig.name()
431
+ returns_type = sig.returns_type().cpp_type()
432
+ args = sig.arguments()
433
+ args_str = ", ".join(a.defn() for a in args)
434
+
435
+ # See Note [Direct dispatch bindings]
436
+ cpp_sig_group = CppSignatureGroup.from_native_function(
437
+ f, method=False, fallback_binding=False
438
+ )
439
+
440
+ # TODO: dedupe this with the structured codegen
441
+ if self.target is Target.NAMESPACED_DECLARATION:
442
+ result = ""
443
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
444
+ result += f"TORCH_API {cpp_sig.decl()};\n"
445
+ return result
446
+ elif self.target is Target.NAMESPACED_DEFINITION:
447
+
448
+ def generate_defn(cpp_sig: CppSignature) -> str:
449
+ return f"""
450
+ {cpp_sig.defn()} {{
451
+ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
452
+ }}
453
+ """
454
+
455
+ result = ""
456
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
457
+ result += generate_defn(cpp_sig)
458
+ return result
459
+
460
+ elif self.target is Target.ANONYMOUS_DEFINITION:
461
+ # short circuit for inplace_meta
462
+ if inplace_meta:
463
+ assert f.func.arguments.self_arg is not None
464
+ self_arg_name = f.func.arguments.self_arg.argument.name
465
+ # TODO: handle in place on tensor list
466
+ return f"""
467
+ {returns_type} {name}({args_str}) {{
468
+ TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
469
+ "Cannot inplace into non-meta tensor with meta tensor argument");
470
+ return {self_arg_name};
471
+ }}
472
+ """
473
+
474
+ # short circuit for generated inplace/out wrappers
475
+ if gets_out_inplace_wrapper:
476
+ return self.gen_out_inplace_wrapper(f, g)
477
+
478
+ metadata = self.backend_index.get_kernel(f)
479
+ if metadata is None:
480
+ return None
481
+ if self.class_method_name is None:
482
+ impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
483
+ else:
484
+ impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
485
+
486
+ kernel_sig = kernel_signature(f, self.backend_index)
487
+
488
+ args_exprs_str = ", ".join(
489
+ e.expr
490
+ for e in translate(
491
+ sig.arguments(), kernel_sig.arguments(), method=False
492
+ )
493
+ )
494
+
495
+ device_check = " // No device check\n"
496
+ # Backends that require device guards presumably also require device checks.
497
+ if self.backend_index.device_guard:
498
+ device_check_args = itertools.chain(
499
+ f.func.arguments.out, f.func.arguments.flat_positional
500
+ )
501
+ device_check = RegisterDispatchKey.gen_device_check(
502
+ f.device_check, list(device_check_args), name
503
+ )
504
+
505
+ device_guard = "// DeviceGuard omitted" # default
506
+ if f.device_guard and self.backend_index.device_guard:
507
+ has_tensor_options = any(
508
+ isinstance(a, TensorOptionsArguments)
509
+ for a in f.func.arguments.non_out
510
+ )
511
+ if has_tensor_options:
512
+ # kernel is creating a tensor
513
+ device_guard = """
514
+ const DeviceGuard device_guard(device_or_default(device));"""
515
+
516
+ # CUDA requires special handling
517
+ if is_cuda_dispatch_key(self.backend_index.dispatch_key):
518
+ device_guard = (
519
+ f"globalContext().lazyInitCUDA();\n{device_guard}"
520
+ )
521
+ else:
522
+ # kernel is operating on existing tensors
523
+
524
+ # There is precedence for which argument we use to do
525
+ # device guard. This describes the precedence order.
526
+ self_arg = (
527
+ [f.func.arguments.self_arg.argument]
528
+ if f.func.arguments.self_arg is not None
529
+ else []
530
+ )
531
+ candidate_args = itertools.chain(
532
+ self_arg,
533
+ f.func.arguments.out,
534
+ f.func.arguments.flat_positional,
535
+ )
536
+
537
+ # Only tensor like arguments are eligible
538
+ device_of = next(
539
+ (
540
+ f"{a.name}"
541
+ for a in candidate_args
542
+ if a.type.is_tensor_like()
543
+ ),
544
+ None,
545
+ )
546
+ if device_of is not None:
547
+ device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
548
+
549
+ return f"""\
550
+ namespace {{
551
+
552
+ {returns_type} {name}({args_str}) {{
553
+ {device_check}
554
+
555
+ {device_guard}
556
+ return {impl_name}({args_exprs_str});
557
+ }}
558
+
559
+ }} // anonymous namespace
560
+ """
561
+
562
+ elif self.target is Target.REGISTRATION:
563
+ if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
564
+ return None
565
+ else:
566
+ payload = f"TORCH_FN({name})"
567
+ return f'm.impl("{f.func.name}",\n{payload});\n'
568
+ else:
569
+ assert_never(self.target)
570
+
571
+
572
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
573
+ #
574
+ # STRUCTURED
575
+ #
576
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
577
+
578
+
579
+ @dataclass(frozen=True)
580
+ class StructuredRegisterDispatchKey(RegisterDispatchKey):
581
+ g: NativeFunctionsGroup
582
+
583
+ def gen_class_set_output_functions(
584
+ self, k: SchemaKind, parent_class: str, generate_super: bool
585
+ ) -> str:
586
+ if generate_super:
587
+ set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
588
+ else:
589
+ set_output_super = ""
590
+
591
+ def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
592
+ return f"""
593
+ void set_output_{name}(
594
+ int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
595
+ TensorOptions options, DimnameList names
596
+ ) override {{
597
+ {textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
598
+ if (!names.empty()) {{
599
+ namedinference::propagate_names(outputs_[output_idx], names);
600
+ }}
601
+ // super must happen after, so that downstream can use maybe_get_output
602
+ // to retrieve the output
603
+ {textwrap.indent(set_output_super, " ")}
604
+ }}
605
+ """
606
+
607
+ return f"""
608
+ {gen_set_output_function("strided", maybe_create_proxy=True)}
609
+ {gen_set_output_function("raw_strided", maybe_create_proxy=False)}
610
+ """
611
+
612
+ def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
613
+ if self.backend_index.dispatch_key in [
614
+ DispatchKey.CUDA,
615
+ DispatchKey.MPS,
616
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
617
+ ]:
618
+ maybe_set_guard = """
619
+ auto current_device = guard_.current_device();
620
+ if (C10_UNLIKELY(current_device.has_value())) {
621
+ TORCH_INTERNAL_ASSERT(*current_device == options.device(),
622
+ "structured kernels don't support multi-device outputs");
623
+ } else {
624
+ guard_.reset_device(options.device());
625
+ }
626
+ """
627
+ maybe_set_guard_line = maybe_set_guard + "\n"
628
+ else:
629
+ maybe_set_guard_line = maybe_set_guard = ""
630
+
631
+ if maybe_create_proxy:
632
+ create_proxy = """
633
+ auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
634
+ if (C10_UNLIKELY(maybe_proxy.has_value())) {
635
+ proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
636
+ }
637
+ """
638
+ else:
639
+ create_proxy = ""
640
+
641
+ if k is SchemaKind.functional:
642
+ assert self.backend_index.dispatch_key in (
643
+ DispatchKey.Meta,
644
+ DispatchKey.CPU,
645
+ DispatchKey.CUDA,
646
+ DispatchKey.MPS,
647
+ DispatchKey.XPU,
648
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
649
+ )
650
+ return f"""{maybe_set_guard_line}
651
+ outputs_[output_idx] = create_out(sizes, strides, options);"""
652
+ elif k is SchemaKind.inplace:
653
+ return f"""{maybe_set_guard_line}
654
+ const auto& out = outputs_[output_idx].get();
655
+ check_inplace(out, sizes, options);
656
+ {create_proxy}"""
657
+ elif k is SchemaKind.out:
658
+ return f"""{maybe_set_guard_line}
659
+ const auto& out = outputs_[output_idx].get();
660
+ resize_out(out, sizes, strides, options);
661
+ {create_proxy}"""
662
+ elif k is SchemaKind.mutable or k is SchemaKind.scratch:
663
+ raise AssertionError(
664
+ f"{k} structured operators are currently not supported"
665
+ )
666
+ else:
667
+ assert_never(k)
668
+
669
+ # returns the definition of a ctor, as well as how to construct
670
+ # this class to a variable named op
671
+ def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
672
+ if k is SchemaKind.functional:
673
+ return ""
674
+ elif k is SchemaKind.inplace:
675
+ # TODO: Make sure out argument is guaranteed to be self
676
+ return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
677
+ elif k is SchemaKind.out:
678
+ out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
679
+ out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
680
+ return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
681
+ elif k is SchemaKind.mutable or k is SchemaKind.scratch:
682
+ raise AssertionError(
683
+ f"{k} structured operators are currently not supported"
684
+ )
685
+ else:
686
+ assert_never(k)
687
+
688
+ def gen_class(
689
+ self,
690
+ f: NativeFunction,
691
+ k: SchemaKind,
692
+ *,
693
+ class_name: str,
694
+ parent_class: str,
695
+ generate_super: bool,
696
+ ) -> str:
697
+ if k is SchemaKind.functional:
698
+ output_type = "Tensor"
699
+ output_value = "outputs_[output_idx]"
700
+ proxy_field = ""
701
+ elif k is SchemaKind.inplace:
702
+ output_type = "std::reference_wrapper<Tensor>"
703
+ output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
704
+ proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
705
+ elif k is SchemaKind.out:
706
+ output_type = "std::reference_wrapper<Tensor>"
707
+ output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
708
+ proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
709
+ else:
710
+ raise RuntimeError(f"Unsupported SchemaKind {k}")
711
+
712
+ if self.backend_index.dispatch_key == DispatchKey.CUDA:
713
+ if self.rocm:
714
+ guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
715
+ else:
716
+ guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
717
+ elif (
718
+ self.backend_index.dispatch_key
719
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
720
+ ):
721
+ guard_field = "c10::OptionalDeviceGuard guard_;"
722
+ elif self.backend_index.dispatch_key == DispatchKey.MPS:
723
+ # TODO: Move to OptionalMPSGuard.
724
+ guard_field = "c10::OptionalDeviceGuard guard_;"
725
+ else:
726
+ guard_field = ""
727
+
728
+ indent = " " * 4
729
+ class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
730
+ lines = (
731
+ f"struct {class_name} final : public {parent_class} {{",
732
+ f"{textwrap.indent(class_ctor_str, indent)}",
733
+ f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
734
+ " const Tensor& maybe_get_output(int64_t output_idx) override {",
735
+ f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit
736
+ " }",
737
+ # type: ignore[possibly-undefined] # TODO: audit
738
+ f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
739
+ f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit
740
+ f"{textwrap.indent(guard_field, indent)}",
741
+ "};",
742
+ )
743
+ return "\n".join(line for line in lines if line)
744
+
745
    @method_with_native_function
    def gen_one(self, f: NativeFunction) -> str | None:
        """Generate code for one native function of a structured group.

        What is emitted depends on ``self.target``:

        - ``NAMESPACED_DECLARATION``: ``TORCH_API`` declarations of the
          direct-call entry points (e.g. ``at::cpu::add``).
        - ``NAMESPACED_DEFINITION``: definitions of those entry points,
          forwarding to the wrapper function.
        - ``ANONYMOUS_DEFINITION``: the structured-kernel class plus the
          wrapper function that drives meta/impl and returns the outputs.
        - ``REGISTRATION``: the ``m.impl(...)`` registration line.

        Returns None when the function is not selected for registration or
        when no default implementation should be generated.
        """
        assert not f.manual_kernel_registration

        if (
            self.target is Target.REGISTRATION
            and not self.selector.is_native_function_selected(f)
        ):
            return None

        # TODO: Now, there is something interesting going on here. In the code below,
        # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
        # based on the out implementation. But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor. How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other. We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG! So it's not implemented yet.
        if (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
            and f.func.kind() is SchemaKind.out
        ):
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False
        )

        # Signature of the wrapper function we'll register to the dispatcher
        kern = self.backend_index.get_kernel(f)
        sig = NativeSignature(
            f.func,
            prefix=f"wrapper_{self.backend_index.dispatch_key}_",
            symint=kern is not None and kern.supports_symint(),
        )

        if self.target is Target.NAMESPACED_DECLARATION:
            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += f"TORCH_API {cpp_sig.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                # Forward the namespaced entry point to the wrapper function.
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += generate_defn(cpp_sig)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: list[Binding | Expr] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            elif (
                self.backend_index.dispatch_key
                is DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            else:
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"

            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional
                )
                sig_body.append(
                    RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), sig.name()
                    )
                )

            # Construct the structured-kernel class instance; out variants
            # receive the preallocated outputs, inplace variants receive self.
            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ", ".join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ", ".join(
                e.expr
                for e in translate(
                    context, structured.meta_arguments(self.g), method=False
                )
            )

            if self.g.out.precomputed:
                # If this function group has precomputed elements, the meta function
                # returns a struct containing them which must be saved so that it
                # can be unpacked when generating code to call the impl.
                sig_body.append(f"auto precompute = op.meta({meta_exprs});")

                # Put all of the contents of the precompute struct into the context
                # so that translate will be able to return the correct args for the
                # call to the impl.
                precomputed_values = [
                    *self.g.out.precomputed.replace.values(),
                    self.g.out.precomputed.add,
                ]
                for precomputed_elems in precomputed_values:
                    for arg in precomputed_elems:
                        context.append(
                            Expr(
                                expr=f"precompute.{arg.name}",
                                type=structured.argument_type(arg, binds=arg.name),
                            )
                        )

                # Add a use of the precompute struct so FB internal compilers don't
                # complain that there is an unused variable.
                sig_body.append("(void)precompute;")
            else:
                sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type

                if k is SchemaKind.out:
                    expr = f"op.maybe_get_output({i})"
                else:
                    expr = f"op.outputs_[{i}]"

                context.append(
                    Expr(
                        expr=expr,
                        # TODO: Stop hardcoding that the output type is a Tensor. Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=NamedCType(
                            out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
                        ),
                    )
                )

            # With the expanded context, do the impl call (if not a meta
            # function)
            if (
                self.backend_index.dispatch_key
                == DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding
                )
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ", ".join(
                    e.expr
                    for e in translate(context, out_sig.arguments(), method=False)
                )
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ", ".join(
                    e.expr
                    for e in translate(
                        context, structured.impl_arguments(self.g), method=False
                    )
                )
                sig_body.append(f"op.impl({impl_exprs});")

            # Go over each output, and check if there is a proxy created for it.
            # If so, copy it over to the original output.
            if k is SchemaKind.out or k is SchemaKind.inplace:
                for i in range(len(f.func.returns)):
                    sig_body.append(
                        f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
                    )

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0])"  # small optimization
                else:
                    moved = ", ".join(
                        f"std::move(op.outputs_[{i}])"
                        for i in range(len(f.func.returns))
                    )
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ", ".join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")  # type: ignore[possibly-undefined]  # TODO: audit

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Sequence, TYPE_CHECKING
5
+
6
+ import torchgen.api.ufunc as ufunc
7
+ from torchgen.api.translate import translate
8
+ from torchgen.api.types import (
9
+ BaseCType,
10
+ Binding,
11
+ CType,
12
+ Expr,
13
+ NamedCType,
14
+ opmath_t,
15
+ scalar_t,
16
+ StructuredImplSignature,
17
+ VectorizedCType,
18
+ )
19
+ from torchgen.context import with_native_function
20
+ from torchgen.model import (
21
+ Argument,
22
+ BaseTy,
23
+ BaseType,
24
+ DispatchKey,
25
+ NativeFunctionsGroup,
26
+ ScalarType,
27
+ UfuncKey,
28
+ )
29
+ from torchgen.utils import OrderedSet
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from torchgen.api.ufunc import UfunctorBindings
34
+
35
+
36
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
37
+ #
38
+ # CUDA STUFF
39
+ #
40
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
41
+
42
# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary
44
+
45
+ # TODO: use BackendIndex
46
+ # dispatch_key: DispatchKey # only CPU/CUDA right now
47
+
48
+
49
+ # Represents functors for implementing CUDA ufuncs.
50
+ # Functors are templated by scalar_t because when USERS instantiate functors
51
+ # they are templated. A functor looks something like this:
52
+ #
53
+ # template <typename scalar_t>
54
+ # struct CUDAFunctorOnSelf_add {
55
+ # using opmath_t = at::opmath_type<scalar_t>;
56
+ # opmath_t other_;
57
+ # opmath_t alpha_;
58
+ # CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
59
+ # : other_(other), alpha_(alpha) {}
60
+ # __device__ scalar_t operator()(scalar_t self) {
61
+ # return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
62
+ # }
63
+ # };
64
+ #
65
@dataclass(frozen=True)
class UfunctorSignature:
    """Signature of a generated CUDA ufunctor struct.

    See the worked example in the comment block above: the ctor captures
    scalar/captured arguments as fields, and operator() takes the
    per-element tensor arguments.
    """

    # The structured function group this functor is generated for.
    g: NativeFunctionsGroup
    # Index of the tensor argument that is passed to the ctor as a CPU
    # scalar (0 or 1 in the binary specializations); None means no scalar
    # specialization and all tensor inputs flow through operator().
    scalar_tensor_idx: int | None
    # C++ name of the generated functor struct.
    name: str

    def arguments(self) -> UfunctorBindings:
        """Split bindings into ctor-captured vs. operator()-applied args."""
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> list[Binding]:
        # fields are renamed to have a trailing underscore, as is conventional
        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

    def returns_type(self) -> CType:
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        """Render the member-field declarations, one per line."""
        return "\n".join(f"{f.type} {f.name};" for f in self.fields())

    def inline_defn_ctor(self) -> str:
        """Render the inline constructor definition with its init list."""
        args_str = ", ".join(a.decl() for a in self.arguments().ctor)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
        return f"{self.name}({args_str}) : {init_str} {{}}"

    def decl_apply(self) -> str:
        """Render the declaration of operator() (without __device__)."""
        args_str = ", ".join(a.decl() for a in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
98
+
99
+
100
@dataclass(frozen=True)
class UfuncSignature:
    """Signature of the inner ufunc template function (e.g. ``ufunc::add``)."""

    # Function group the ufunc belongs to.
    g: NativeFunctionsGroup
    # Fully qualified C++ name of the ufunc.
    name: str
    # Compute type the ufunc is instantiated at (scalar_t, opmath_t, ...).
    compute_t: CType

    def arguments(self) -> list[Binding]:
        """Bindings for the ufunc's parameters at ``compute_t``."""
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Binding | Expr]) -> str:
        """Render a C++ call expression, translating args from ``ctx``."""
        actuals = [binding.expr for binding in translate(ctx, self.arguments())]
        return f"{self.name}({', '.join(actuals)})"
111
+
112
+
113
+ # steps:
114
+ # 1. take the functional signature
115
+ # 2. use api.ufunc to convert it to template signature. this establishes
116
+ # the type of the template function
117
+ # 3. use api.ufunc (II) to generate a split struct / operator() signature.
118
+ # this establish context in which we call the template signature
119
+ #
120
+ # StructuredImplSignature context
121
+ # ~> functor constructor sig
122
+ #
123
+ # Functor constructor context
124
+ # ~> functor fields sig
125
+ #
126
+ # Functor apply context (functor fields + functor apply sig)
127
+ # ~> template sig
128
+ #
129
+
130
+
131
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True when ``g`` takes exactly two tensor-like inputs.

    Only binary tensor ops get the CPU-scalar CUDA functor specializations
    (CUDAFunctorOnSelf / CUDAFunctorOnOther).
    """
    tensor_args = [
        a for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    ]
    return len(tensor_args) == 2
136
+
137
+
138
def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
    """Build the CUDA functor structs for ``g``.

    Returns a mapping from dtype to the functor signature to use per
    UfuncKey, plus the C++ source text for all generated functor structs.
    """
    # First, build the functors.
    ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: list[str] = []
    loops = g.out.ufunc_inner_loop
    # Which tensor input is passed to the ctor as a CPU scalar for each key.
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, both ScalarOnly and Generic are defined, the ufunc name
        # must match  (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        # Functors compute in opmath_t; operator() casts from scalar_t.
        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)
219
+
220
+
221
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    """Describes one CPU-scalar specialization for a binary CUDA ufunc."""

    # Zero-based index of the input that may be a CPU scalar.
    scalar_idx: int
    # Name of the tensor argument captured by the functor's ctor.
    ctor_tensor: str
    # The UfuncKey whose functor handles this specialization.
    ufunc_key: UfuncKey
226
+
227
+
228
# The two possible CPU-scalar specializations for a binary op: either the
# first input (self) or the second input (other) is the scalar; the functor
# then captures the *other* tensor argument in its ctor.
BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]
240
+
241
+
242
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Render the per-dtype body of the CUDA kernel dispatch.

    Emits an if/else chain: one branch per applicable CPU-scalar
    specialization (removing the scalar operand from the iterator),
    then a fallback that launches the generic CUDAFunctor.
    """
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        # TensorIterator operand indices are 1-based for inputs here
        # (operand 0 is the output).
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: list[Expr | Binding] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
    """
    return body
287
+
288
+
289
@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Generate the full CUDA-side source for a ufunc group.

    Produces the functor structs, the dtype-dispatching kernel, its stub
    registration, and the structured impl that calls the kernel directly.
    """
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(sig.arguments())};
}}
"""
329
+
330
+
331
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
332
+ #
333
+ # CPU STUFF
334
+ #
335
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
336
+
337
+
338
@dataclass(frozen=True)
class StubSignature:
    """Naming and rendering helpers for a ufunc's DispatchStub.

    Derives the stub, kernel, and function-pointer-type names from the
    group's base operator name and renders the corresponding C++
    declarations, definitions, and call expressions.
    """

    g: NativeFunctionsGroup

    @property
    def name(self) -> str:
        # e.g. "add_stub"
        base = str(self.g.functional.func.name.name)
        return base + "_stub"

    @property
    def kernel_name(self) -> str:
        # e.g. "add_kernel"
        base = str(self.g.functional.func.name.name)
        return base + "_kernel"

    @property
    def type_name(self) -> str:
        # e.g. "add_fn"
        base = str(self.g.functional.func.name.name)
        return base + "_fn"

    def arguments(self) -> list[Binding]:
        """Extra (non-iterator) arguments the stub carries."""
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        """Function-pointer type of the stub."""
        extra_types = ", ".join(a.type for a in self.arguments())
        return f"void(*)(TensorIteratorBase&, {extra_types})"

    def dispatch_decl(self) -> str:
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        params = ", ".join(a.defn() for a in self.arguments())
        return f"void {self.kernel_name}(TensorIteratorBase& iter, {params})"

    def type_defn(self) -> str:
        return f"using {self.type_name} = {self.type()}"

    # must be called from context where this is TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        actuals = ", ".join(a.expr for a in translate(ctx, self.arguments()))
        return f"{self.name}(device_type(), *this, {actuals})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        actuals = ", ".join(a.expr for a in translate(ctx, self.arguments()))
        return f"{self.kernel_name}(*this, {actuals})"
380
+
381
+
382
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Generate the ATen-side CPU glue for a ufunc group.

    Emits the stub type alias, its DECLARE/DEFINE_DISPATCH lines, and the
    structured impl that forwards to the stub (the kernel itself lives in
    a separately generated .cpp, see compute_ufunc_cpu_kernel).
    """
    stub = StubSignature(g)
    impl_sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    return f"""
{stub.type_defn()};
{stub.dispatch_decl()};
{stub.dispatch_defn()};

{impl_sig.defn()} {{
  {stub.call(impl_sig.arguments())};
}}
"""
396
+
397
+
398
def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Render the per-dtype body of the CPU kernel.

    Unpacks Scalar arguments once (into scalar and, if vectorizing,
    Vectorized copies), then emits a cpu_kernel / cpu_kernel_vec call with
    lambdas invoking the scalar (and optional vector) ufunc loops.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = None
    if UfuncKey.CPUVector in inner_loops:
        vec_loop = inner_loops[UfuncKey.CPUVector]

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope
    body = []
    ctx = []
    for b in parent_ctx:
        # Only Scalar arguments need unpacking; tensors flow via the lambda.
        if isinstance(b.argument, Argument) and b.argument.type != BaseType(
            BaseTy.Scalar
        ):
            continue
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in parent_ctx:
            if isinstance(b.argument, Argument) and b.argument.type != BaseType(
                BaseTy.Scalar
            ):
                continue
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
        # Combined context: unpacked scalars plus the lambda's parameters.
        r: list[Expr | Binding] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = "\n".join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""
{body_str}
cpu_kernel(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""
489
+
490
+
491
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Generate the CPU kernel .cpp source for a ufunc group.

    Reindexes the inner loops by dtype (resolving ScalarOnly/Generic
    precedence), builds the AT_DISPATCH switch, and registers the kernel
    against the group's DispatchStub.
    """
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                compute_t: CType
                if k is UfuncKey.CPUScalar:
                    compute_t = BaseCType(scalar_t)
                elif k is UfuncKey.CPUVector:
                    compute_t = VectorizedCType(BaseCType(scalar_t))
                else:
                    raise AssertionError
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                # First match wins (higher-precedence lk was visited first).
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)
    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""
.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (192 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/model.cpython-311.pyc ADDED
Binary file (11.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc ADDED
Binary file (6.98 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-311.pyc ADDED
Binary file (7.53 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-311.pyc ADDED
Binary file (15.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/custom_ops.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import Sequence, TYPE_CHECKING
6
+
7
+ from torchgen import dest
8
+
9
+
10
+ # disable import sorting to avoid circular dependency.
11
+ from torchgen.api.types import DispatcherSignature # usort: skip
12
+ from torchgen.context import method_with_native_function
13
+ from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
14
+ from torchgen.utils import concatMap, Target
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from torchgen.executorch.model import ETKernelIndex
19
+ from torchgen.selective_build.selector import SelectiveBuilder
20
+
21
+
22
+ # Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
23
+ # model authoring side.
24
+ @dataclass(frozen=True)
25
+ class ComputeNativeFunctionStub:
26
+ @method_with_native_function
27
+ def __call__(self, f: NativeFunction) -> str | None:
28
+ if Variant.function not in f.variants:
29
+ return None
30
+
31
+ sig = DispatcherSignature.from_schema(
32
+ f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
33
+ )
34
+ assert sig is not None
35
+ if len(f.func.returns) == 0:
36
+ ret_name = ""
37
+ elif len(f.func.returns) == 1:
38
+ if f.func.arguments.out:
39
+ ret_name = f.func.arguments.out[0].name
40
+ else:
41
+ ret_name = next(
42
+ (
43
+ a.name
44
+ for a in f.func.arguments.flat_non_out
45
+ if a.type == f.func.returns[0].type
46
+ ),
47
+ "",
48
+ )
49
+ if not ret_name:
50
+ # if return type is tensor
51
+ if f.func.returns[0].type == BaseType(BaseTy.Tensor):
52
+ # Returns an empty tensor
53
+ ret_name = "at::Tensor()"
54
+ else:
55
+ raise Exception( # noqa: TRY002
56
+ f"Can't handle this return type {f.func}"
57
+ ) # noqa: TRY002
58
+ elif len(f.func.arguments.out) == len(f.func.returns):
59
+ # Returns a tuple of out arguments
60
+ tensor_type = "at::Tensor &"
61
+ comma = ", "
62
+ ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
63
+ {comma.join([r.name for r in f.func.arguments.out])}
64
+ )"""
65
+ else:
66
+ assert all(
67
+ a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
68
+ ), f"Only support tensor returns but got {f.func.returns}"
69
+ # Returns a tuple of empty tensors
70
+ tensor_type = "at::Tensor"
71
+ comma = ", "
72
+ ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
73
+ {comma.join(["at::Tensor()" for _ in f.func.returns])}
74
+ )"""
75
+ ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
76
+ return f"""
77
+ {sig.defn()} {{
78
+ {ret_str}
79
+ }}
80
+ """
81
+
82
+
83
+ def gen_custom_ops_registration(
84
+ *,
85
+ native_functions: Sequence[NativeFunction],
86
+ selector: SelectiveBuilder,
87
+ kernel_index: ETKernelIndex,
88
+ rocm: bool,
89
+ ) -> tuple[str, str]:
90
+ """
91
+ Generate custom ops registration code for dest.RegisterDispatchKey.
92
+
93
+ :param native_functions: a sequence of `NativeFunction`
94
+ :param selector: for selective build.
95
+ :param kernel_index: kernels for all the ops.
96
+ :param rocm: bool for dest.RegisterDispatchKey.
97
+ :return: generated C++ code to register custom operators into PyTorch
98
+ """
99
+
100
+ # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
101
+ # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.
102
+
103
+ dispatch_key = DispatchKey.CPU
104
+ backend_index = kernel_index._to_backend_index()
105
+ static_init_dispatch_registrations = ""
106
+ ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
107
+ for native_function in native_functions:
108
+ ns_grouped_native_functions[native_function.namespace].append(native_function)
109
+
110
+ for namespace, functions in ns_grouped_native_functions.items():
111
+ if len(functions) == 0:
112
+ continue
113
+ dispatch_registrations_body = "\n".join(
114
+ list(
115
+ concatMap(
116
+ dest.RegisterDispatchKey(
117
+ backend_index,
118
+ Target.REGISTRATION,
119
+ selector,
120
+ rocm=rocm,
121
+ symint=False,
122
+ class_method_name=None,
123
+ skip_dispatcher_op_registration=False,
124
+ ),
125
+ functions,
126
+ )
127
+ )
128
+ )
129
+ static_init_dispatch_registrations += f"""
130
+ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
131
+ {dispatch_registrations_body}
132
+ }};"""
133
+ anonymous_definition = "\n".join(
134
+ list(
135
+ concatMap(
136
+ dest.RegisterDispatchKey(
137
+ backend_index,
138
+ Target.ANONYMOUS_DEFINITION,
139
+ selector,
140
+ rocm=rocm,
141
+ symint=False,
142
+ class_method_name=None,
143
+ skip_dispatcher_op_registration=False,
144
+ ),
145
+ native_functions,
146
+ )
147
+ )
148
+ )
149
+ return anonymous_definition, static_init_dispatch_registrations
.venv/lib/python3.11/site-packages/torchgen/executorch/api/et_cpp.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence
4
+
5
+ from torchgen import local
6
+ from torchgen.api.types import (
7
+ ArgName,
8
+ BaseCType,
9
+ Binding,
10
+ ConstRefCType,
11
+ CType,
12
+ MutRefCType,
13
+ NamedCType,
14
+ SpecialArgName,
15
+ TupleCType,
16
+ VectorCType,
17
+ voidT,
18
+ )
19
+ from torchgen.executorch.api.types import (
20
+ ArrayRefCType,
21
+ BaseTypeToCppMapping,
22
+ OptionalCType,
23
+ scalarT,
24
+ tensorListT,
25
+ tensorT,
26
+ )
27
+ from torchgen.model import (
28
+ Argument,
29
+ Arguments,
30
+ BaseTy,
31
+ BaseType,
32
+ ListType,
33
+ NativeFunction,
34
+ OptionalType,
35
+ Return,
36
+ SelfArgument,
37
+ TensorOptionsArguments,
38
+ Type,
39
+ )
40
+ from torchgen.utils import assert_never
41
+
42
+
43
+ """
44
+ This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
45
+ functions like at::add. It also serves as a native function API, which is the signature of kernels,
46
+ since in Executorch CppSignature is the same as NativeSignature.
47
+
48
+ Difference between this file and torchgen.api.cpp.py:
49
+
50
+ - Executorch doesn't support TensorOptions, however in this file we still keep the logic here to be compatible with
51
+ torchgen.api.cpp, so that we can do stuff like ATen mode (running ATen kernels in Executorch).
52
+
53
+ - Executorch doesn't support Dimname.
54
+
55
+ - Executorch runtime doesn't support SymInt, will treat it as int.
56
+ """
57
+
58
+
59
+ # Translation of "value types" in JIT schema to C++ API type. Value
60
+ # types look the same no matter if they are argument types or return
61
+ # types. Returns None if the type in question is not a value type.
62
+ def valuetype_type(
63
+ t: Type,
64
+ *,
65
+ binds: ArgName,
66
+ remove_non_owning_ref_types: bool = False,
67
+ ) -> NamedCType | None:
68
+ if isinstance(t, BaseType):
69
+ if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
70
+ return None
71
+ # For SymInt we simply treat it as int.
72
+ elif str(t) == "SymInt":
73
+ return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
74
+ if remove_non_owning_ref_types:
75
+ if t.name == BaseTy.str:
76
+ raise AssertionError(
77
+ "string ref->value conversion: not implemented yet"
78
+ )
79
+ # All other BaseType currently map directly to BaseCppTypes.
80
+ return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
81
+ elif isinstance(t, OptionalType):
82
+ elem = valuetype_type(t.elem, binds=binds)
83
+ if elem is None:
84
+ return None
85
+ return NamedCType(binds, OptionalCType(elem.type))
86
+ elif isinstance(t, ListType):
87
+ if str(t.elem) == "bool":
88
+ assert t.size is not None
89
+ return NamedCType(
90
+ binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]))
91
+ )
92
+ else:
93
+ return None
94
+ else:
95
+ raise AssertionError(f"unrecognized type {repr(t)}")
96
+
97
+
98
+ # Translation of types occurring in JIT arguments to a C++ argument type.
99
+ # If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type.
100
+ # For example, we'll return std::vector<int> instead of IntArrayRef.
101
+ # See Note [translation from C++ reference to value types]
102
+ def argumenttype_type(
103
+ t: Type,
104
+ *,
105
+ mutable: bool,
106
+ binds: ArgName,
107
+ remove_non_owning_ref_types: bool = False,
108
+ ) -> NamedCType:
109
+ # If it's a value type, do the value type translation
110
+ r = valuetype_type(
111
+ t,
112
+ binds=binds,
113
+ remove_non_owning_ref_types=remove_non_owning_ref_types,
114
+ )
115
+ if r is not None:
116
+ return r
117
+ if isinstance(t, BaseType):
118
+ if t.name == BaseTy.Tensor:
119
+ if mutable and not local.use_const_ref_for_mutable_tensors():
120
+ return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
121
+ else:
122
+ return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
123
+ elif t.name == BaseTy.Scalar:
124
+ return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
125
+ else:
126
+ raise AssertionError(f"base type should have been value type {t}")
127
+ elif isinstance(t, OptionalType):
128
+ if str(t.elem) == "Tensor":
129
+ if mutable and not local.use_const_ref_for_mutable_tensors():
130
+ return NamedCType(
131
+ binds, MutRefCType(BaseCType(tensorT))
132
+ ) # TODO: fix this discrepancy
133
+ else:
134
+ return NamedCType(
135
+ binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
136
+ )
137
+ elif str(t.elem) == "Scalar":
138
+ return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
139
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
140
+ return NamedCType(binds, OptionalCType(elem.type))
141
+ elif isinstance(t, ListType):
142
+ # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
143
+ if str(t.elem) == "Tensor":
144
+ return NamedCType(binds, BaseCType(tensorListT))
145
+ elif str(t.elem) == "Dimname":
146
+ raise NotImplementedError("Executorch doesn't support Dimname")
147
+ elif str(t.elem) == "Tensor?":
148
+ return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
149
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
150
+ return NamedCType(binds, ArrayRefCType(elem.type))
151
+ else:
152
+ raise AssertionError(f"unrecognized type {repr(t)}")
153
+
154
+
155
+ # Translate a JIT argument into its C++ type
156
+ def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
157
+ return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
158
+
159
+
160
+ # Translation of a (non-multi) return type from JIT to C++
161
+ # N.B: returntype_type returns a CType, not a NamedCType.
162
+ # This is mostly because of the mismatch between return types and return names.
163
+ # e.g. a function with a return type of 'void' has 0 return names,
164
+ # and a function with a return type of 'std::tuple' has >1 return name.
165
+ def returntype_type(t: Type, *, mutable: bool) -> CType:
166
+ # placeholder is ignored
167
+ r = valuetype_type(t, binds="__placeholder__")
168
+ if r is not None:
169
+ return r.type
170
+
171
+ if isinstance(t, BaseType):
172
+ if t.name == BaseTy.Tensor:
173
+ if mutable:
174
+ if local.use_const_ref_for_mutable_tensors():
175
+ return ConstRefCType(BaseCType(tensorT))
176
+ else:
177
+ return MutRefCType(BaseCType(tensorT))
178
+ else:
179
+ # Note [Tensor Copy Returns]
180
+ # Currently, we use "Argument.is_write" to determine
181
+ # whether or not Tensor return types should be copies or references.
182
+ # If that ever changes, take a look at other locations of this note!
183
+ return BaseCType(tensorT)
184
+ elif t.name == BaseTy.Scalar:
185
+ return BaseCType(scalarT)
186
+ elif isinstance(t, ListType):
187
+ assert (
188
+ not mutable
189
+ ), "Native functions should never return a mutable tensor list. They should return void."
190
+ elem = returntype_type(t.elem, mutable=False)
191
+ assert t.size is None, f"fixed size list returns not supported: {t}"
192
+ return VectorCType(elem)
193
+
194
+ raise AssertionError(f"unrecognized return type {t}")
195
+
196
+
197
+ # Translation of a single return to its C++ type
198
+ def return_type(r: Return) -> CType:
199
+ return returntype_type(r.type, mutable=r.is_write)
200
+
201
+
202
+ # Translation of a full (possibly multi) return from JIT to its C++ type
203
+ def returns_type(rs: Sequence[Return]) -> CType:
204
+ if len(rs) == 0:
205
+ return BaseCType(voidT)
206
+ elif len(rs) == 1:
207
+ return return_type(rs[0])
208
+ else:
209
+ return TupleCType([return_type(r) for r in rs])
210
+
211
+
212
+ def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
213
+ returns: list[str] = []
214
+ for i, r in enumerate(f.func.returns):
215
+ # If we have an inplace function, the return argument is
216
+ # implicitly named self.
217
+ # TODO: Consider incorporating this into the data model
218
+ if f.func.name.name.inplace:
219
+ assert i == 0, "illegal inplace function with multiple returns"
220
+ name = "self"
221
+ # If we are out function, the name is the name of the
222
+ # corresponding output function (r.name will get recorded
223
+ # in field_name later.)
224
+ elif f.func.is_out_fn():
225
+ name = f.func.arguments.out[i].name
226
+ # If the return argument is explicitly named...
227
+ elif r.name:
228
+ name_conflict = any(
229
+ r.name == a.name for a in f.func.schema_order_arguments()
230
+ )
231
+ if name_conflict and not f.func.is_out_fn():
232
+ name = f"{r.name}_return"
233
+ else:
234
+ name = r.name
235
+ # If there is no explicit name and no fallback name was passed in, we just name the output result,
236
+ # unless it's a multi-return, in which case it's result0,
237
+ # result1, etc (zero-indexed)
238
+ else:
239
+ name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
240
+ returns.append(name)
241
+ return returns
242
+
243
+
244
+ JIT_TO_CPP_DEFAULT = {
245
+ "False": "false",
246
+ "True": "true",
247
+ "None": "torch::executorch::nullopt", # UGH this one is type directed
248
+ "[]": "{}",
249
+ "contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
250
+ "long": "torch::executorch::kLong",
251
+ }
252
+
253
+
254
+ # Convert a JIT default into C++ expression representing the default
255
+ def default_expr(d: str, t: Type) -> str:
256
+ if d == "None" and str(t) == "Tensor?":
257
+ return "{}"
258
+ if isinstance(t, BaseType) and t.name is BaseTy.str:
259
+ # Schema allows single quotes but C++ needs double
260
+ if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
261
+ s = ""
262
+ i = 1
263
+ while i + 1 < len(d):
264
+ if d[i] != "\\":
265
+ if d[i] == '"':
266
+ s += '\\"'
267
+ else:
268
+ s += d[i]
269
+ i += 1
270
+ else:
271
+ if d[i + 1] == "'":
272
+ s += "'"
273
+ else:
274
+ s += d[i : i + 2]
275
+ i += 2
276
+
277
+ return f'"{s}"'
278
+
279
+ if isinstance(t, OptionalType):
280
+ if d == "None":
281
+ return "torch::executor::nullopt"
282
+
283
+ return default_expr(d, t.elem)
284
+
285
+ if isinstance(t, ListType):
286
+ if d.startswith("[") and d.endswith("]"):
287
+ return "{" + d[1:-1] + "}"
288
+ elif t.size is None:
289
+ # NOTE: Sized lists can have scalar defaults
290
+ raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
291
+
292
+ return JIT_TO_CPP_DEFAULT.get(d, d)
293
+
294
+
295
+ # Convert an argument into its C++ API form
296
+
297
+
298
+ def argument(
299
+ a: Argument | TensorOptionsArguments | SelfArgument,
300
+ *,
301
+ cpp_no_default_args: set[str],
302
+ method: bool,
303
+ faithful: bool,
304
+ has_tensor_options: bool,
305
+ ) -> list[Binding]:
306
+ def sub_argument(
307
+ a: Argument | TensorOptionsArguments | SelfArgument,
308
+ ) -> list[Binding]:
309
+ return argument(
310
+ a,
311
+ cpp_no_default_args=cpp_no_default_args,
312
+ method=method,
313
+ faithful=faithful,
314
+ has_tensor_options=has_tensor_options,
315
+ )
316
+
317
+ if isinstance(a, Argument):
318
+ binds: ArgName
319
+ if a.name == "memory_format" and has_tensor_options:
320
+ binds = SpecialArgName.possibly_redundant_memory_format
321
+ else:
322
+ binds = a.name
323
+ default: str | None = None
324
+ if a.name not in cpp_no_default_args and a.default is not None:
325
+ default = default_expr(a.default, a.type)
326
+ return [
327
+ Binding(
328
+ nctype=argument_type(a, binds=binds),
329
+ name=a.name,
330
+ default=default,
331
+ argument=a,
332
+ )
333
+ ]
334
+ elif isinstance(a, TensorOptionsArguments):
335
+ raise NotImplementedError("Need to implement type resolution for TensorOptions")
336
+ elif isinstance(a, SelfArgument):
337
+ if method:
338
+ # Caller is responsible for installing implicit this in context!
339
+ return []
340
+ else:
341
+ return sub_argument(a.argument)
342
+ else:
343
+ assert_never(a)
344
+
345
+
346
+ def arguments(
347
+ arguments: Arguments,
348
+ *,
349
+ faithful: bool,
350
+ method: bool,
351
+ cpp_no_default_args: set[str],
352
+ ) -> list[Binding]:
353
+ args: list[Argument | TensorOptionsArguments | SelfArgument] = []
354
+ if faithful:
355
+ args.extend(arguments.non_out)
356
+ args.extend(arguments.out)
357
+ else:
358
+ args.extend(arguments.out)
359
+ args.extend(arguments.non_out)
360
+ return [
361
+ r.no_default() if faithful else r
362
+ for a in args
363
+ for r in argument(
364
+ a,
365
+ faithful=faithful,
366
+ method=method,
367
+ has_tensor_options=arguments.tensor_options is not None,
368
+ cpp_no_default_args=cpp_no_default_args,
369
+ )
370
+ ]
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from torchgen.executorch.api.types.types import *
2
+
3
+
4
+ from torchgen.executorch.api.types.signatures import * # usort: skip
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (326 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-311.pyc ADDED
Binary file (4.76 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-311.pyc ADDED
Binary file (4.29 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/signatures.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torchgen.api.cpp as aten_cpp
7
+ from torchgen.executorch.api.types.types import contextArg
8
+
9
+
10
+ if TYPE_CHECKING:
11
+ from torchgen.api.types import Binding, CType
12
+ from torchgen.model import FunctionSchema, NativeFunction
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class ExecutorchCppSignature:
17
+ """
18
+ This signature is merely a CppSignature with Executorch types (optionally
19
+ contains KernelRuntimeContext as well). The inline definition of
20
+ CppSignature is generated in Functions.h and it's used by unboxing
21
+ functions.
22
+ """
23
+
24
+ # The schema this signature is derived from
25
+ func: FunctionSchema
26
+
27
+ # The set of C++ arguments which should not have defaults applied to them
28
+ cpp_no_default_args: set[str]
29
+
30
+ # Allows you to prepend an arbitrary prefix to the signature name.
31
+ # This is useful for parts of the codegen that generate wrappers around kernels,
32
+ # and need to avoid naming collisions.
33
+ prefix: str = ""
34
+
35
+ def arguments(self, *, include_context: bool = True) -> list[Binding]:
36
+ return ([contextArg] if include_context else []) + et_cpp.arguments(
37
+ self.func.arguments,
38
+ faithful=True, # always faithful, out argument at the end
39
+ method=False, # method not supported
40
+ cpp_no_default_args=self.cpp_no_default_args,
41
+ )
42
+
43
+ def name(self) -> str:
44
+ return self.prefix + aten_cpp.name(
45
+ self.func,
46
+ faithful_name_for_out_overloads=True,
47
+ )
48
+
49
+ def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
50
+ args_str = ", ".join(
51
+ a.decl() for a in self.arguments(include_context=include_context)
52
+ )
53
+ if name is None:
54
+ name = self.name()
55
+ return f"{self.returns_type().cpp_type()} {name}({args_str})"
56
+
57
+ def defn(self, name: str | None = None) -> str:
58
+ args = [a.defn() for a in self.arguments()]
59
+ args_str = ", ".join(args)
60
+ if name is None:
61
+ name = self.name()
62
+ return f"{self.returns_type().cpp_type()} {name}({args_str})"
63
+
64
+ def returns_type(self) -> CType:
65
+ return et_cpp.returns_type(self.func.returns)
66
+
67
+ @staticmethod
68
+ def from_native_function(
69
+ f: NativeFunction, *, prefix: str = ""
70
+ ) -> ExecutorchCppSignature:
71
+ return ExecutorchCppSignature(
72
+ func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
73
+ )
74
+
75
+
76
+ from torchgen.executorch.api import et_cpp
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/types.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from torchgen.api.types import (
6
+ BaseCppType,
7
+ BaseCType,
8
+ Binding,
9
+ boolT,
10
+ CType,
11
+ doubleT,
12
+ Expr,
13
+ longT,
14
+ MutRefCType,
15
+ NamedCType,
16
+ )
17
+ from torchgen.model import BaseTy
18
+
19
+
20
+ halfT = BaseCppType("torch::executor", "Half")
21
+ bfloat16T = BaseCppType("torch::executor", "BFloat16")
22
+ stringT = BaseCppType("torch::executor", "string_view")
23
+ scalarTypeT = BaseCppType("torch::executor", "ScalarType")
24
+ tensorT = BaseCppType("torch::executor", "Tensor")
25
+ tensorListT = BaseCppType("torch::executor", "TensorList")
26
+ scalarT = BaseCppType("torch::executor", "Scalar")
27
+ memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
28
+ intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
29
+ optionalT = BaseCppType("torch::executor", "optional")
30
+ contextT = BaseCppType("torch::executor", "KernelRuntimeContext")
31
+
32
+ contextExpr = Expr(
33
+ expr="context",
34
+ type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
35
+ )
36
+
37
+ contextArg = Binding(
38
+ name="context",
39
+ nctype=contextExpr.type,
40
+ argument=None, # type: ignore[arg-type]
41
+ default=None,
42
+ )
43
+
44
+ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
45
+ BaseTy.int: longT,
46
+ BaseTy.float: doubleT,
47
+ BaseTy.bool: boolT,
48
+ BaseTy.str: stringT,
49
+ BaseTy.ScalarType: scalarTypeT,
50
+ BaseTy.Tensor: tensorT,
51
+ BaseTy.Scalar: scalarT,
52
+ BaseTy.MemoryFormat: memoryFormatT,
53
+ }
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class OptionalCType(CType):
58
+ elem: CType
59
+
60
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
61
+ # Do not pass `strip_ref` recursively.
62
+ return f"torch::executor::optional<{self.elem.cpp_type()}>"
63
+
64
+ def cpp_type_registration_declarations(self) -> str:
65
+ return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"
66
+
67
+ def remove_const_ref(self) -> CType:
68
+ return OptionalCType(self.elem.remove_const_ref())
69
+
70
+
71
+ @dataclass(frozen=True)
72
+ class ArrayRefCType(CType):
73
+ elem: CType
74
+
75
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
76
+ # Do not pass `strip_ref` recursively.
77
+ return f"torch::executor::ArrayRef<{self.elem.cpp_type()}>"
78
+
79
+ def cpp_type_registration_declarations(self) -> str:
80
+ return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
81
+
82
+ def remove_const_ref(self) -> CType:
83
+ return ArrayRefCType(self.elem.remove_const_ref())
.venv/lib/python3.11/site-packages/torchgen/executorch/api/unboxing.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Sequence, TYPE_CHECKING
5
+
6
+ from torchgen.model import (
7
+ Argument,
8
+ BaseTy,
9
+ BaseType,
10
+ ListType,
11
+ NativeFunction,
12
+ OptionalType,
13
+ Type,
14
+ )
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from torchgen.api.types import Binding, CType, NamedCType
19
+
20
+
21
+ connector = "\n\t"
22
+
23
+
24
+ # Return unboxing function name for a NativeFunction
25
+ def name(f: NativeFunction) -> str:
26
+ return f.func.name.unambiguous_name()
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class Unboxing:
31
+ """
32
+ Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
33
+ A sample generated code:
34
+ // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
35
+ void mul_out(EValue** stack) {
36
+ EValue& self = *stack[0];
37
+ EValue& other = *stack[1];
38
+ EValue& out = *stack[2];
39
+ const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
40
+ const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
41
+ torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
42
+
43
+ EXECUTORCH_SCOPE_PROF("native_call_mul.out");
44
+ torch::executor::mul_outf(self_base, other_base, out_base);
45
+
46
+
47
+ }
48
+ """
49
+
50
+ # this is a callable that converts a JIT argument, into its C++ type.
51
+ # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
52
+ argument_type_gen: Callable[
53
+ ...,
54
+ NamedCType,
55
+ ]
56
+
57
+ # Convert all the arguments in a NativeFunction to C++ code
58
+ def convert_arguments(
59
+ self, args: Sequence[Binding]
60
+ ) -> tuple[list[Binding], list[str]]:
61
+ code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
62
+ binding_list = []
63
+ for arg in args:
64
+ # expecting only Argument
65
+ if not isinstance(arg.argument, Argument):
66
+ raise Exception( # noqa: TRY002
67
+ f"Unexpected argument type, expecting `Argument` but got {arg}"
68
+ )
69
+ argument: Argument = arg.argument
70
+ unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
71
+ argument.type, argument.name, mutable=argument.is_write
72
+ )
73
+ code_list.extend(decl)
74
+ code_list.extend(code)
75
+ binding_list.append(arg.with_name(unboxed_name))
76
+ return binding_list, code_list
77
+
78
+ def argumenttype_evalue_convert(
79
+ self, t: Type, arg_name: str, *, mutable: bool = False
80
+ ) -> tuple[str, CType, list[str], list[str]]:
81
+ """
82
+ Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
83
+ (1) the C++ code necessary to unbox the argument
84
+ (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
85
+ :param t: a `Type` of an argument
86
+ :param arg_name: argument name
87
+ :param mutable: boolean for whether this argument type is mutable
88
+ :return: unboxed result
89
+ """
90
+ ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
91
+
92
+ if isinstance(t, BaseType):
93
+ out_name = f"{arg_name}_base"
94
+ code, decl = self._gen_code_base_type(
95
+ arg_name=arg_name, out_name=out_name, ctype=ctype
96
+ )
97
+ elif isinstance(t, OptionalType):
98
+ out_name = f"{arg_name}_opt_out"
99
+ code, decl = self._gen_code_optional_type(
100
+ arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
101
+ )
102
+ elif isinstance(t, ListType):
103
+ out_name = f"{arg_name}_list_out"
104
+ code, decl = self._gen_code_list_type(
105
+ arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
106
+ )
107
+ else:
108
+ raise Exception( # noqa: TRY002
109
+ f"Cannot handle type {t}. arg_name: {arg_name}"
110
+ ) # noqa: TRY002
111
+ return out_name, ctype, code, decl
112
+
113
+ def _gen_code_base_type(
114
+ self, arg_name: str, out_name: str, ctype: CType
115
+ ) -> tuple[list[str], list[str]]:
116
+ return [
117
+ f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
118
+ ], []
119
+
120
+ def _gen_code_optional_type(
121
+ self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
122
+ ) -> tuple[list[str], list[str]]:
123
+ in_name = f"{arg_name}_opt_in"
124
+ res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
125
+ t.elem, in_name
126
+ )
127
+ return (
128
+ f"""
129
+ auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
130
+ """.split(
131
+ "\n"
132
+ ),
133
+ decl,
134
+ )
135
+
136
+ def _gen_code_list_type(
137
+ self, arg_name: str, out_name: str, t: ListType, ctype: CType
138
+ ) -> tuple[list[str], list[str]]:
139
+ in_name = f"{arg_name}_list_in"
140
+ elem_name = f"{arg_name}_elem"
141
+ code = []
142
+ res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
143
+ t.elem, elem_name
144
+ )
145
+
146
+ if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
147
+ code.extend(
148
+ f"""
149
+ auto {out_name} = {arg_name}.toTensorList();
150
+ """.split(
151
+ "\n"
152
+ )
153
+ )
154
+ elif isinstance(t.elem, BaseType) and (
155
+ t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
156
+ ):
157
+ code.extend(
158
+ f"""
159
+ auto {out_name} = {arg_name}.toIntList();
160
+ """.split(
161
+ "\n"
162
+ )
163
+ )
164
+ elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
165
+ code.extend(
166
+ f"""
167
+ auto {out_name} = {arg_name}.toDoubleList();
168
+ """.split(
169
+ "\n"
170
+ )
171
+ )
172
+ elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
173
+ # handle list type with size, e.g., bool[4]
174
+ code.extend(
175
+ f"""
176
+ #ifdef USE_ATEN_LIB
177
+ std::array<bool, {t.size}> {out_name};
178
+ auto {in_name} = {arg_name}.toBoolList();
179
+ size_t _i = 0;
180
+ for (auto {elem_name}: {in_name}) {{
181
+ {out_name}[_i++] = {elem_name};
182
+ }}
183
+ #else
184
+ auto {out_name} = {arg_name}.toBoolList();
185
+ #endif
186
+ """.split(
187
+ "\n"
188
+ )
189
+ )
190
+ # pytorch codegen:
191
+ # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
192
+ elif (
193
+ isinstance(t.elem, OptionalType)
194
+ and isinstance(t.elem.elem, BaseType)
195
+ and t.elem.elem.name == BaseTy.Tensor
196
+ ):
197
+ code.extend(
198
+ f"""
199
+ #ifdef USE_ATEN_LIB
200
+ auto {in_name} = {arg_name}.toListOptionalTensor();
201
+ c10::List<::std::optional<at::Tensor>> {out_name};
202
+ for (auto {elem_name}: {in_name}) {{
203
+ {out_name}.push_back({elem_name});
204
+ }}
205
+ #else
206
+ auto {out_name} = {arg_name}.toListOptionalTensor();
207
+ #endif
208
+ """.split(
209
+ "\n"
210
+ )
211
+ )
212
+ else:
213
+ # use ArrayRef as default.
214
+ vec_name = arg_name + "_vec"
215
+ # need to bring vector instantiation out of scope so that ArrayRef has valid data
216
+ decl.append(
217
+ f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
218
+ )
219
+ code.extend(
220
+ f"""
221
+ for (EValue {elem_name}: {in_name}) {{
222
+ {connector.join(res_code)}
223
+ {vec_name}.push_back({res_name});
224
+ }}
225
+ {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
226
+ """.split(
227
+ "\n"
228
+ )
229
+ )
230
+ return code, decl
.venv/lib/python3.11/site-packages/torchgen/executorch/model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Represents all kernels used by an Executorch model.
2
+ # It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
3
+
4
+ from __future__ import annotations
5
+
6
+ import itertools
7
+ from collections import defaultdict, namedtuple
8
+ from dataclasses import dataclass
9
+ from enum import IntEnum
10
+
11
+ from torchgen.model import (
12
+ BackendIndex,
13
+ BackendMetadata,
14
+ DispatchKey,
15
+ NativeFunction,
16
+ NativeFunctionsGroup,
17
+ OperatorName,
18
+ )
19
+ from torchgen.utils import assert_never
20
+
21
+
22
+ KERNEL_KEY_VERSION = 1
23
+
24
+
25
+ # TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
26
+ class ScalarType(IntEnum):
27
+ Byte = 0
28
+ Char = 1
29
+ Short = 2
30
+ Int = 3
31
+ Long = 4
32
+ Float = 6
33
+ Double = 7
34
+ Bool = 11
35
+
36
+
37
+ ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class ETKernelKeyOpArgMeta:
42
+ arg_name: str
43
+ dtype: str
44
+ # The order of the dimensions if entry is a Tensor
45
+ dim_order: tuple[int, ...]
46
+
47
+ def to_native_string(self) -> str:
48
+ dtype_str = ScalarType[self.dtype].value
49
+ dim_str = str(self.dim_order)[1:-1].replace(" ", "")
50
+ return f"{dtype_str};{dim_str}"
51
+
52
+
53
+ @dataclass(frozen=True)
54
+ class ETKernelKey:
55
+ # Field undefined is default = True
56
+ arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
57
+
58
+ # Indicator for this kernel being used as a catch all
59
+ default: bool = False
60
+
61
+ version: int = KERNEL_KEY_VERSION
62
+
63
+ @staticmethod
64
+ def gen_from_yaml(
65
+ args: dict[str, tuple[str, str]],
66
+ type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
67
+ dim_order_alias_map: dict[str, list[int]],
68
+ ) -> list[ETKernelKey]:
69
+ """Generate ETKernelKeys from arg kernel specs
70
+ Multiple ETKernelKeys are returned due to dtype permutations from utilizing
71
+ type_alias_map (actualizing each potential type permutation as a KernelKey)
72
+
73
+ Args:
74
+ args: Mapping from argument name to kernel specs
75
+ Kernel specs are a tuple of (dtype, dim_order).
76
+ Currently tuple entries must be aliased via the alias map arguments
77
+ type_alias_map: Mapping from type alias to potential type enums
78
+ i.e { T0 : [Double, Int] } means T0 can be either Double or Int
79
+ Used for lookup by args
80
+ dim_order_alias_map: Mapping from alias to a list of dimension orders
81
+ Used for lookup by args
82
+ """
83
+ # Cast to dim order to int
84
+ dim_order_alias_map = {
85
+ k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
86
+ }
87
+ kernel_keys = []
88
+
89
+ # Get all used Dtype Alias
90
+ dtype_alias_used = set()
91
+ for type_alias, dim_order in args.values():
92
+ # Enforce usage of alias initially
93
+ # TODO: Support inlined arguments
94
+ assert type_alias in type_alias_map, "Undefined type alias: " + str(
95
+ type_alias
96
+ )
97
+ assert (
98
+ dim_order in dim_order_alias_map
99
+ ), "Undefined dim_order alias: " + str(dim_order)
100
+ dtype_alias_used.add(type_alias)
101
+
102
+ # Generate all permutations of dtype alias values
103
+ alias_dtypes = [
104
+ [(alias, dtype) for dtype in type_alias_map[alias]]
105
+ for alias in dtype_alias_used
106
+ ]
107
+ alias_permutations = [
108
+ dict(permutation) for permutation in list(itertools.product(*alias_dtypes))
109
+ ]
110
+
111
+ # Using each alias value permutation, generate kernel keys
112
+ op_arg_cache = {}
113
+ for permutation in alias_permutations:
114
+ arg_list = []
115
+ for arg_name, arg_spec in args.items():
116
+ dtype = permutation[arg_spec[0]]
117
+ dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment]
118
+ if (
119
+ cache_key := (arg_name, dtype, tuple(dim_order))
120
+ ) not in op_arg_cache:
121
+ op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type]
122
+
123
+ arg_list.append(op_arg_cache[cache_key])
124
+ kernel_keys.append(ETKernelKey(tuple(arg_list)))
125
+
126
+ return kernel_keys
127
+
128
+ def to_native_string(self) -> str:
129
+ if self.default:
130
+ return "default"
131
+ return (
132
+ "v"
133
+ + str(KERNEL_KEY_VERSION)
134
+ + "/"
135
+ + "|".join([arg.to_native_string() for arg in self.arg_meta])
136
+ )
137
+
138
+
139
+ @dataclass(frozen=True)
140
+ class ETKernelIndex:
141
+ index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
142
+
143
+ def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
144
+ m = self.get_kernels(g)
145
+ return m is not None
146
+
147
+ def get_kernels(
148
+ self, g: NativeFunction | NativeFunctionsGroup
149
+ ) -> dict[ETKernelKey, BackendMetadata]:
150
+ if isinstance(g, NativeFunction):
151
+ f = g
152
+ elif isinstance(g, NativeFunctionsGroup):
153
+ f = g.functional
154
+ else:
155
+ assert_never(g)
156
+ if f.func.name not in self.index:
157
+ return {}
158
+ return self.index[f.func.name]
159
+
160
+ @staticmethod
161
+ def grow_from_backend_indices(
162
+ kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
163
+ backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
164
+ ) -> None:
165
+ for dk in backend_indices:
166
+ index = backend_indices[dk]
167
+ for op, backend_metadata in index.items():
168
+ if op in kernel_index:
169
+ kernel_index[op][ETKernelKey(default=True)] = backend_metadata
170
+ else:
171
+ kernel_index[op] = {ETKernelKey(default=True): backend_metadata}
172
+
173
+ @staticmethod
174
+ def from_backend_indices(
175
+ backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
176
+ ) -> ETKernelIndex:
177
+ kernel_index: dict[
178
+ OperatorName, dict[ETKernelKey, BackendMetadata]
179
+ ] = defaultdict(dict)
180
+ ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
181
+ return ETKernelIndex(kernel_index)
182
+
183
+ def grow(
184
+ self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
185
+ ) -> ETKernelIndex:
186
+ ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
187
+ return self
188
+
189
+ def _to_backend_index(self) -> BackendIndex:
190
+ """
191
+ WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
192
+ """
193
+ index: dict[OperatorName, BackendMetadata] = {}
194
+ for op in self.index:
195
+ kernel_dict = self.index[op]
196
+ assert (
197
+ len(kernel_dict.values()) == 1
198
+ ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
199
+ index[op] = kernel_dict.get(
200
+ ETKernelKey(default=True),
201
+ BackendMetadata(kernel="", structured=False, cpp_namespace=""),
202
+ )
203
+ return BackendIndex(
204
+ dispatch_key=DispatchKey.CPU,
205
+ use_out_as_primary=False,
206
+ device_guard=False,
207
+ external=False,
208
+ index=index,
209
+ )
210
+
211
+ # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
212
+ @staticmethod
213
+ def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
214
+ combined = defaultdict(dict, index_a.index.copy())
215
+
216
+ for op, entry in index_b.index.items():
217
+ for key, metadata in entry.items():
218
+ combined[op][key] = metadata
219
+
220
+ return ETKernelIndex(combined)
.venv/lib/python3.11/site-packages/torchgen/executorch/parse.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict, namedtuple
4
+ from typing import Any
5
+
6
+ import yaml
7
+
8
+ from torchgen.executorch.model import ETKernelIndex, ETKernelKey
9
+ from torchgen.gen import LineLoader, parse_native_yaml
10
+ from torchgen.model import (
11
+ BackendMetadata,
12
+ DispatchKey,
13
+ FunctionSchema,
14
+ NativeFunction,
15
+ OperatorName,
16
+ )
17
+ from torchgen.utils import NamespaceHelper
18
+
19
+
20
+ # Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
21
+ ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])
22
+
23
+ # Fields in native_functions.yaml used to determine which kernels should be used
24
+ ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
25
+
26
+
27
+ def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
28
+ """Given a loaded yaml representing kernel assignment information, extract the
29
+ mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
30
+
31
+ Args:
32
+ ei: Dict keys {kernels, type_alias, dim_order_alias}
33
+ See ETKernelKey for description of arguments
34
+ """
35
+ e = ei.copy()
36
+ if (kernels := e.pop("kernels", None)) is None:
37
+ return {}
38
+
39
+ type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
40
+ dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
41
+ dim_order_alias.pop("__line__", None)
42
+
43
+ kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
44
+
45
+ for entry in kernels: # type: ignore[attr-defined]
46
+ arg_meta = entry.get("arg_meta")
47
+ if arg_meta is not None:
48
+ arg_meta.pop("__line__")
49
+
50
+ kernel_name = entry.get("kernel_name")
51
+ namespace_helper = NamespaceHelper.from_namespaced_entity(
52
+ kernel_name, max_level=3
53
+ )
54
+ kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
55
+ backend_metadata = BackendMetadata(
56
+ kernel=namespace_helper.entity_name,
57
+ structured=False,
58
+ cpp_namespace=(kernel_namespace + "::native"),
59
+ )
60
+
61
+ kernel_keys = (
62
+ [ETKernelKey((), default=True)]
63
+ if arg_meta is None
64
+ else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type]
65
+ )
66
+
67
+ for kernel_key in kernel_keys:
68
+ assert kernel_key not in kernel_mapping, (
69
+ "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
70
+ )
71
+ kernel_mapping[kernel_key] = backend_metadata
72
+
73
+ return kernel_mapping
74
+
75
+
76
+ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
77
+ """Given a loaded yaml representing a list of operators, for each op extract the mapping
78
+ of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
79
+ that should be used by the kernel key).
80
+ """
81
+ indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
82
+ for ei in es: # type: ignore[attr-defined]
83
+ e = ei.copy()
84
+
85
+ funcs = e.pop("func")
86
+ assert isinstance(funcs, str), f"not a str: {funcs}"
87
+ namespace_helper = NamespaceHelper.from_namespaced_entity(
88
+ namespaced_entity=funcs, max_level=1
89
+ )
90
+ opname = FunctionSchema.parse(namespace_helper.entity_name).name
91
+
92
+ assert opname not in indices, f"Duplicate func found in yaml: {opname} already"
93
+
94
+ if len(index := parse_from_yaml(e)) != 0:
95
+ indices[opname] = index
96
+
97
+ return ETKernelIndex(indices)
98
+
99
+
100
+ def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
101
+ """Given a loaded yaml representing a list of operators, extract the
102
+ kernel key related fields indexed by the operator name.
103
+ """
104
+ fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
105
+ for ei in es: # type: ignore[attr-defined]
106
+ funcs = ei.get("func")
107
+ assert isinstance(funcs, str), f"not a str: {funcs}"
108
+ namespace_helper = NamespaceHelper.from_namespaced_entity(
109
+ namespaced_entity=funcs, max_level=1
110
+ )
111
+ opname = FunctionSchema.parse(namespace_helper.entity_name).name
112
+
113
+ for field in ET_FIELDS:
114
+ if (value := ei.get(field)) is not None:
115
+ fields[opname][field] = value
116
+
117
+ return fields
118
+
119
+
120
+ def parse_et_yaml(
121
+ path: str,
122
+ tags_yaml_path: str,
123
+ ignore_keys: set[DispatchKey] | None = None,
124
+ skip_native_fns_gen: bool = False,
125
+ ) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
126
+ """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
127
+ of fields to persist from native_functions.yaml to functions.yaml
128
+ """
129
+ with open(path) as f:
130
+ es = yaml.load(f, Loader=LineLoader)
131
+
132
+ et_kernel = extract_kernel_fields(es)
133
+
134
+ # Remove ET specific fields from entries for BC compatibility
135
+ strip_et_fields(es)
136
+
137
+ native_yaml = parse_native_yaml(
138
+ path,
139
+ tags_yaml_path,
140
+ ignore_keys,
141
+ skip_native_fns_gen=skip_native_fns_gen,
142
+ loaded_yaml=es,
143
+ )
144
+ return native_yaml.native_functions, et_kernel
145
+
146
+
147
+ def strip_et_fields(es: object) -> None:
148
+ """Given a loaded yaml representing a list of operators,
149
+ remove ET specific fields from every entries for BC compatibility
150
+ """
151
+ for entry in es: # type: ignore[attr-defined]
152
+ for field in ET_FIELDS:
153
+ entry.pop(field, None)
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/native_functions.yaml ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/tags.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml`
2
+
3
+ - tag: inplace_view
4
+ desc: |
5
+ This tag indicates if an operator *only* modifies the tensor metadata
6
+ - tag: pt2_compliant_tag
7
+ desc: |
8
+ This tag indicates if the operator is guaranteed to
9
+ work with the PT2 compilation APIs (torch.compile,
10
+ torch.export, etc). If you add this tag to an
11
+ operator, please use
12
+ `torch.testing._internal.optest.opcheck` to test that
13
+ the operator has been registered correctly and
14
+ works with torch.compile
15
+ - tag: view_copy
16
+ desc: |
17
+ This tag indicates operators that are *_copy* variants
18
+ of view/aliasing operators. If an operator has a view_copy tag,
19
+ then it should have the name {op}_copy, where {op} is a view operator.
20
+ - tag: dynamic_output_shape
21
+ desc: |
22
+ This tag indicates if an operator's output's shape depends on input Tensor
23
+ data.
24
+ - tag: data_dependent_output
25
+ desc: |
26
+ Operator has a non-Tensor output whose value is dependent on the data
27
+ of Tensor inputs. Among other things, this implies that this operator
28
+ cannot be run with meta tensor (since data is not available), nor
29
+ can it be symbolically traced.
30
+ - tag: generated
31
+ desc: |
32
+ This tag indicates that the operator doesn't have an explicit entry in
33
+ native_functions.yaml, and instead was generated automatically by the codegen.
34
+ - tag: nondeterministic_seeded
35
+ desc: |
36
+ This tag indicates if an operator is nondeterministically seeded
37
+ (i.e., is random) such that the operator intentionally produces
38
+ different results when run twice on the same inputs, but this randomness
39
+ is controlled by a Generator which, if reseeded would give you the
40
+ same result.
41
+ - tag: nondeterministic_bitwise
42
+ desc: |
43
+ This tag indicates if an operator doesn't guarantee bitwise equivalence
44
+ across different runs of an operator with identical inputs.
45
+ - tag: needs_fixed_stride_order
46
+ desc: |
47
+ This tag indicates that the operator should be passed Tensors following
48
+ the same stride permutation as observed in eager when compiled in inductor.
49
+ Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
50
+ multiple are assigned then we assume the most restrictive one.
51
+ - tag: flexible_layout
52
+ desc: |
53
+ This tag indicates that the custom operator can accept inputs with varying
54
+ strides/storage_offset and that when compiled, Inductor is allowed to change
55
+ the strides/storage_offset of inputs to the custom operator.
56
+ Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
57
+ multiple are assigned then we assume the most restrictive one.
58
+
59
+ # NOTE [Core ATen Ops]
60
+ - tag: core
61
+ desc: |
62
+ Core aten ops is a subset of aten ops that remains after aten-to-aten decomposition and
63
+ functionalization pass. Core aten ops are fully functional and adhere to single static
64
+ assignment (SSA): this implies there will be no `inplace` or `_out` variants in this opset.
65
+ This opset is designed to serve as the functional IR to interface with compiler backends.
66
+ In contrast to primTorch, core aten opset doesn't decompose ops into explicit
67
+ type promotion and broadcasting ops.
68
+ Core aten ops is also effectively the opset produced by torchdynamo.export(aten_graph=True),
69
+ and thus can be used as an opset for export purpose.
70
+ - tag: pointwise
71
+ desc: |
72
+ Pointwise operators are operators where each element of the output is computed only by accessing
73
+ the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted
74
+ shape of the inputs.
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/ATenOpList.h>
2
+
3
+ #include <string>
4
+ #include <cstring>
5
+ #include <utility>
6
+ #include <unordered_set>
7
+ #include <ATen/core/operator_name.h>
8
+
9
+ // ${generated_comment}
10
+
11
+ namespace at {
12
+
13
+ namespace {
14
+ struct OpNameEquals final {
15
+ bool operator()(const std::pair<const char*, const char*>& lhs, const std::pair<const char*, const char*>& rhs) const {
16
+ return 0 == strcmp(lhs.first, rhs.first) && 0 == strcmp(lhs.second, rhs.second);
17
+ }
18
+ };
19
+
20
+ struct OpNameHash final {
21
+ size_t operator()(const std::pair<const char*, const char*>& p) const {
22
+ // use std::hash<std::string> because std::hash<const char*> would hash pointers and not pointed-to strings
23
+ return std::hash<std::string>()(p.first) ^ (~ std::hash<std::string>()(p.second));
24
+ }
25
+ };
26
+ }
27
+
28
+ bool is_custom_op(const c10::OperatorName& opName) {
29
+ static std::unordered_set<std::pair<const char*, const char*>, OpNameHash, OpNameEquals> ops {
30
+ ${aten_ops}
31
+ {"", ""}
32
+ };
33
+ return ops.count(std::make_pair(
34
+ opName.name.c_str(), opName.overload_name.c_str())) == 0;
35
+ }
36
+ }
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2
+ // ${generated_comment}
3
+
4
+ #include <ATen/InferSize.h>
5
+ #include <ATen/Tensor.h>
6
+ #include <ATen/native/Resize.h>
7
+
8
+ #ifndef AT_PER_OPERATOR_HEADERS
9
+ #include <ATen/Operators.h>
10
+ #else
11
+ #include <ATen/ops/clone.h>
12
+ $ops_headers
13
+ #endif
14
+
15
+ namespace at {
16
+ namespace native {
17
+
18
+ // This file contains a number of kernels for aten functions that are fully code-generated.
19
+ // TODO: rename this file to something more generic.
20
+
21
+ namespace {
22
+ at::Tensor clone_arg(const at::Tensor& t) {
23
+ return t.clone();
24
+ }
25
+
26
+ std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
27
+ std::vector<at::Tensor> out(t_list.size());
28
+ for (const auto& i : c10::irange(t_list.size())) {
29
+ out[i] = t_list[i].clone();
30
+ }
31
+ return out;
32
+ }
33
+
34
+ // duped with gen_resize_out_helper from structured kernels
35
+ void copy_arg(const at::Tensor& dst, const at::Tensor& src) {
36
+ TORCH_CHECK(src.dtype() == dst.dtype(),
37
+ "Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead");
38
+ TORCH_CHECK(src.device() == dst.device(),
39
+ "Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead");
40
+ dst.copy_(src);
41
+ }
42
+
43
+ void copy_arg(const at::TensorList& dst, const at::TensorList& src) {
44
+ TORCH_INTERNAL_ASSERT(dst.size() == src.size());
45
+ for (const auto& i : c10::irange(dst.size())) {
46
+ copy_arg(dst[i], src[i]);
47
+ }
48
+ }
49
+
50
+ // TODO: this doesn't handle restriding empty tensors correctly; see
51
+ // gen_resize_out_helper for the correct algorithm
52
+
53
+ void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) {
54
+ at::native::resize_output(dst, src.sizes());
55
+ }
56
+
57
+ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
58
+ TORCH_INTERNAL_ASSERT(dst.size() == src.size());
59
+ for (const auto& i : c10::irange(dst.size())) {
60
+ at::native::resize_output(dst[i], src[i].sizes());
61
+ }
62
+ }
63
+ }
64
+
65
+
66
+ ${CompositeViewCopyKernel_Definitions}
67
+
68
+ ${GeneratedCompositeFunctional_Definitions}
69
+
70
+ ${GeneratedCompositeOut_Definitions}
71
+
72
+ } // namespace native
73
+ } // namespace at
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // ${generated_comment}
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace ${dispatch_namespace} {
19
+
20
+ ${dispatch_namespaced_declarations}
21
+
22
+ } // namespace ${dispatch_namespace}
23
+ } // namespace at
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBody.h>
2
+
3
+ // TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
4
+ // Code introduced to avoid cyclic dependency in static dispatch is no longer
5
+ // needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
6
+ // to Operators.cpp for supporting multiple backends with multiple kernels.
7
+ //
8
+ // Note [Avoiding Include Cycles In Static Dispatch]
9
+ // In order to avoid #include cycles in the static dispatch build, we've carefully split out
10
+ // the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
11
+ //
12
+ // Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
13
+ // - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
14
+ // all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
15
+ // directly inlined into TensorBody.h.
16
+ // - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
17
+ // which include functions that have defaultable std::optional<Tensor> arguments.
18
+ // That requires knowing the full Tensor class definition.
19
+ //
20
+ // We break the cycle by doing the following:
21
+ // - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
22
+ // - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
23
+ // - CPUFunctions_inl.h includes everything else
24
+ // - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
25
+ // and then it includes CPUFunctions_inl.h.
26
+ // - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
27
+ // - This also means that static dispatch build, CPUFunctions.h only needs to
28
+ // #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
29
+ ${inline_headers}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // ${generated_comment}
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
12
+ #error This change adds a dependency on all pytorch operators, meaning the \
13
+ file will need to be re-compiled every time an operator is changed or added. \
14
+ Consider including a specific operator from \
15
+ <ATen/ops/{my_operator}_${dispatch_namespace}_dispatch.h>. \
16
+ See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
17
+ #endif
18
+
19
+ ${DispatchKeyFunctions_inl_includes}
20
+
21
+
22
+ ${dispatch_namespaced_declarations}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // ${generated_comment}
2
+ ${includes}
3
+ ${native_functions_include}
4
+
5
+ namespace {
6
+ ${helper_fns}
7
+ } // namespace
8
+
9
+ ${namespace_prologue}
10
+
11
+ ${native_function_definitions}
12
+
13
+ ${namespace_epilogue}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // an external backend might generate file within its code tree
4
+ // and check all the source files within the tree with clang-format.
5
+ // so, disable it since the backend might have a different config.
6
+ // clang-format off
7
+
8
+ // ${generated_comment}
9
+
10
+ #include <ATen/Tensor.h>
11
+
12
+ ${namespace_prologue}
13
+
14
+ struct ${class_name} {
15
+
16
+ ${dispatch_declarations}
17
+
18
+ };
19
+ ${namespace_epilogue}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Function.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+ ${static_dispatch_ops_headers}
19
+
20
+ ${operator_includes}
21
+
22
+ namespace at {
23
+
24
+ ${function_definitions}
25
+
26
+ }
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #include <ATen/Tensor.h>
6
+
7
+ namespace at {
8
+ namespace functionalization {
9
+
10
+ enum class InverseReturnMode {
11
+ /// Specifies that functional inverses should always return a view.
12
+ AlwaysView,
13
+ /// Specifies that functional inverses should always return a non-view / copy.
14
+ NeverView,
15
+ /// Specifies that functional inverses should return a view unless a (copying) scatter
16
+ /// inverse exists, in which case that will be used instead.
17
+ /// This avoids as_strided() calls that can be difficult for subclasses to handle.
18
+ ViewOrScatterInverse,
19
+ };
20
+
21
+ struct FunctionalInverses {
22
+
23
+ ${view_inverse_declarations}
24
+
25
+ // NB: These are not generated! They're manually implemented in the template.
26
+ // TODO: Change codegen to generate these. See the following link:
27
+ // https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585
28
+ static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim);
29
+ static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length);
30
+
31
+ };
32
+ }
33
+ }
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.cpp ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <array>
2
+
3
+ #include <ATen/Functions.h>
4
+ #include <ATen/Utils.h>
5
+ #include <c10/core/Allocator.h>
6
+
7
+ namespace at {
8
+
9
+ Tensor TensorMaker::make_tensor() {
10
+ AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
11
+ tracer::impl::NoTracerDispatchMode tracer_guard{};
12
+
13
+ check_size_nonnegative(sizes_);
14
+
15
+ TORCH_CHECK_VALUE(
16
+ !deleter_ || !ctx_,
17
+ "The deleter and context arguments are mutually exclusive.");
18
+
19
+ if (device_ == std::nullopt) {
20
+ device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
21
+ }
22
+
23
+ if (opts_.device().has_index()) {
24
+ // clang-format off
25
+ TORCH_CHECK_VALUE(
26
+ opts_.device() == *device_,
27
+ "Specified device ", opts_.device(), " does not match device of data ", *device_);
28
+ // clang-format on
29
+ }
30
+
31
+ std::size_t size_bytes = computeStorageSize();
32
+
33
+ DataPtr data_ptr{};
34
+ if (deleter_) {
35
+ data_ptr = makeDataPtrFromDeleter();
36
+ } else {
37
+ data_ptr = makeDataPtrFromContext();
38
+ }
39
+
40
+ TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
41
+ Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizable=*/resizeable_};
42
+
43
+ Tensor tensor = detail::make_tensor<TensorImpl>(
44
+ std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
45
+
46
+ TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
47
+ if (strides_) {
48
+ tensor_impl->set_sizes_and_strides(sizes_, *strides_);
49
+ } else {
50
+ tensor_impl->set_sizes_contiguous(sizes_);
51
+ }
52
+ if (storage_offset_) {
53
+ tensor_impl->set_storage_offset(*storage_offset_);
54
+ }
55
+
56
+ return tensor;
57
+ }
58
+
59
+ std::size_t TensorMaker::computeStorageSize() const noexcept {
60
+ std::size_t itemsize = opts_.dtype().itemsize();
61
+
62
+ if (strides_) {
63
+ auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
64
+ if (storage_offset_) {
65
+ storage_size += storage_offset_.value();
66
+ }
67
+ return storage_size;
68
+ }
69
+
70
+ std::size_t size = 1;
71
+ for (std::int64_t s : sizes_) {
72
+ size *= static_cast<std::size_t>(s);
73
+ }
74
+ auto storage_size = size * itemsize;
75
+ if (storage_offset_) {
76
+ storage_size += storage_offset_.value();
77
+ }
78
+ return storage_size;
79
+ }
80
+
81
+ inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept {
82
+ return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_);
83
+ }
84
+
85
+ inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
86
+ return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
87
+ }
88
+
89
+ IntArrayRef TensorMaker::makeTempSizes() const noexcept {
90
+ static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
91
+ if (opts_.has_memory_format()) {
92
+ MemoryFormat format = *opts_.memory_format_opt();
93
+ if (format == MemoryFormat::ChannelsLast) {
94
+ return IntArrayRef(zeros, 4);
95
+ }
96
+ if (format == MemoryFormat::ChannelsLast3d) {
97
+ return IntArrayRef(zeros, 5);
98
+ }
99
+ }
100
+ return IntArrayRef(zeros, 1);
101
+ }
102
+
103
+ } // namespace at
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.h ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #ifdef TORCH_ASSERT_NO_OPERATORS
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if your change would be better placed in \
9
+ another file, or if a more specific header might achieve the same goal. \
10
+ See NOTE: [Tensor vs. TensorBase]
11
+ #endif
12
+
13
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
14
+ #error This change adds a dependency on all pytorch operators, meaning the \
15
+ file will need to be re-compiled every time an operator is changed or added. \
16
+ Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
17
+ see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
18
+ #endif
19
+
20
+ // NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
21
+ //
22
+ // In ATen, certain generated headers files include the definitions of
23
+ // every single operator in PyTorch. Unfortunately this means every
24
+ // time an operator signature is updated or changed in
25
+ // native_functions.yaml, you (and every other PyTorch developer) need
26
+ // to recompile every source file that includes any of these headers.
27
+ //
28
+ // To break up these header dependencies, and improve incremental
29
+ // build times for all PyTorch developers. These headers are split
30
+ // into per-operator headers in the `ATen/ops` folder. This limits
31
+ // incremental builds to only changes to methods of `Tensor`, or files
32
+ // that use the specific operator being changed. With `at::sum` as an
33
+ // example, you should include
34
+ //
35
+ // <ATen/ops/sum.h> // instead of ATen/Functions.h
36
+ // <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
37
+ // <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
38
+ // <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
39
+ //
40
+ // However, even if you're careful to use this in your own code.
41
+ // `Functions.h` might be included indirectly through another header
42
+ // without you realising. To avoid this, you can add
43
+ //
44
+ // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
45
+ //
46
+ // to the top of your source file. This way any time the non-specific
47
+ // headers are included, the compiler will error out.
48
+ //
49
+ // Also, be aware that `ops` are not available in all build
50
+ // configurations (namely fb-internal) so you must guard these
51
+ // includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
52
+ //
53
+ // #ifndef AT_PER_OPERATOR_HEADERS
54
+ // #include <ATen/Functions.h>
55
+ // #else
56
+ // #include <ATen/ops/sum.h>
57
+ // #endif
58
+
59
+ #include <ATen/Context.h>
60
+ #include <ATen/DeviceGuard.h>
61
+ #include <ATen/TensorUtils.h>
62
+ #include <ATen/TracerMode.h>
63
+ #include <ATen/core/Generator.h>
64
+ #include <ATen/core/Reduction.h>
65
+ #include <c10/core/SymInt.h>
66
+ #include <ATen/core/Tensor.h>
67
+ #include <c10/core/Scalar.h>
68
+ #include <c10/core/Storage.h>
69
+ #include <c10/core/TensorOptions.h>
70
+ #include <c10/util/Deprecated.h>
71
+ #include <optional>
72
+ #include <c10/util/OptionalArrayRef.h>
73
+
74
+ #include <ATen/ops/from_blob.h>
75
+ #include <ATen/ops/tensor.h>
76
+
77
+ ${Functions_includes}
78
+
79
+ namespace at {
80
+
81
+ ${Functions_declarations}
82
+
83
+ // Special C++ only overloads for std()-like functions (See gh-40287)
84
+ // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
85
+ // So, for example std(0) would select the std(unbiased=False) overload
86
+ TORCH_API inline Tensor var(const Tensor& self, int dim) {
87
+ return at::var(self, IntArrayRef{dim});
88
+ }
89
+ TORCH_API inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
90
+ return at::var_mean(self, IntArrayRef{dim});
91
+ }
92
+ TORCH_API inline Tensor std(const Tensor& self, int dim) {
93
+ return at::std(self, IntArrayRef{dim});
94
+ }
95
+ TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
96
+ return at::std_mean(self, IntArrayRef{dim});
97
+ }
98
+
99
+ inline int64_t numel(const Tensor& tensor) {
100
+ return tensor.numel();
101
+ }
102
+
103
+ inline int64_t size(const Tensor& tensor, int64_t dim) {
104
+ return tensor.size(dim);
105
+ }
106
+
107
+ inline int64_t stride(const Tensor& tensor, int64_t dim) {
108
+ return tensor.stride(dim);
109
+ }
110
+
111
+ inline bool is_complex(const Tensor& tensor) {
112
+ return tensor.is_complex();
113
+ }
114
+
115
+ inline bool is_floating_point(const Tensor& tensor) {
116
+ return tensor.is_floating_point();
117
+ }
118
+
119
+ inline bool is_signed(const Tensor& tensor) {
120
+ return tensor.is_signed();
121
+ }
122
+
123
+ inline bool is_inference(const Tensor& tensor) {
124
+ return tensor.is_inference();
125
+ }
126
+
127
+ inline bool _is_zerotensor(const Tensor& tensor) {
128
+ return tensor._is_zerotensor();
129
+ }
130
+
131
+ inline bool is_conj(const Tensor& tensor) {
132
+ return tensor.is_conj();
133
+ }
134
+
135
+ inline Tensor conj(const Tensor& tensor) {
136
+ return tensor.conj();
137
+ }
138
+
139
+ inline bool is_neg(const Tensor& tensor) {
140
+ return tensor.is_neg();
141
+ }
142
+
143
+ }
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyIr.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // This file contains autogenerated LazyTensor IR nodes
4
+ ${lazy_ir_sysinc}
5
+ ${lazy_ir_inc}
6
+
7
+ ${namespace_prologue}
8
+ using at::operator<<;
9
+
10
+ // kNullValue is used to contribute a static hash value any time
11
+ // a node has an Optional<Value> input that is nullopt. It is important
12
+ // to differentiate between HASH(std::nullopt, something) and HASH(something, std::nullopt),
13
+ // and using kNullValue in the hash function in the order of arguments
14
+ // serves this purpose.
15
+ static const torch::lazy::Value kNullValue = torch::lazy::Value();
16
+
17
+ ${ir_declarations}
18
+
19
+ ${namespace_epilogue}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ ${lazy_non_native_ir_inc}
4
+
5
+ // This file contains autogenerated LazyTensor Non Native IR nodes
6
+
7
+ ${namespace_prologue}
8
+
9
+ ${non_native_ir_nodes}
10
+
11
+ ${namespace_epilogue}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #ifdef TORCH_ASSERT_NO_OPERATORS
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if your change would be better placed in \
9
+ another file, or if a more specific header might achieve the same goal. \
10
+ See NOTE: [Tensor vs. TensorBase]
11
+ #endif
12
+
13
+ // Forward declarations of any types needed in the operator signatures.
14
+ // We can't directly include these classes because it will cause circular include dependencies.
15
+ // This file is included by TensorBody.h, which defines the Tensor class.
16
+ #include <ATen/core/ATen_fwd.h>
17
+
18
+ ${MethodOperators_includes}
19
+
20
+ namespace at {
21
+ namespace _ops {
22
+ ${MethodOperators_declarations}
23
+ } // namespace _ops
24
+ } // namespace at
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+ ${extra_includes}
16
+
17
+ ${native_function_declarations}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #ifdef TORCH_ASSERT_NO_OPERATORS
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if your change would be better placed in \
9
+ another file, or if a more specific header might achieve the same goal. \
10
+ See NOTE: [Tensor vs. TensorBase]
11
+ #endif
12
+
13
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
14
+ #error This change adds a dependency on all pytorch operators, meaning the \
15
+ file will need to be re-compiled every time an operator is changed or added. \
16
+ Consider including a specific operator from <ATen/ops/{my_operator}_native.h> \
17
+ and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
18
+ #endif
19
+
20
+ #include <c10/core/Scalar.h>
21
+ #include <c10/core/Storage.h>
22
+ #include <c10/core/TensorOptions.h>
23
+ #include <c10/util/Deprecated.h>
24
+ #include <optional>
25
+ #include <c10/core/QScheme.h>
26
+ #include <ATen/core/Reduction.h>
27
+ #include <ATen/core/Tensor.h>
28
+ #include <tuple>
29
+ #include <vector>
30
+
31
+ ${NativeFunctions_includes}
32
+
33
+ ${NativeFunctions_declarations}
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/TensorIterator.h>
13
+ #include <ATen/TensorMeta.h>
14
+ #include <tuple>
15
+ #include <vector>
16
+
17
+ namespace at {
18
+ namespace meta {
19
+
20
+ ${meta_function_declarations}
21
+
22
+ } // namespace native
23
+ } // namespace at