koichi12 commited on
Commit
466ab75
·
verified ·
1 Parent(s): 1c399ca

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/error.py +56 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py +435 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py +41 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/__init__.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py +69 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py +150 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h +595 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py +1851 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py +328 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py +374 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py +706 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py +799 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py +413 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_foreach.py +250 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py +180 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py +130 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py +1543 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py +273 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py +2159 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py +28 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py +655 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py +118 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_heuristics.py +1527 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (20.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/error.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from enum import Enum


class ExportErrorType(Enum):
    """Coarse categories for errors surfaced by the export stack."""

    # User providing invalid inputs to either tracer, or other public facing APIs.
    INVALID_INPUT_TYPE = 1

    # User returning values from their models that we don't support.
    INVALID_OUTPUT_TYPE = 2

    # Generated IR does not conform to Export IR Specification.
    VIOLATION_OF_SPEC = 3

    # User's code contains types and functionalities we don't support.
    NOT_SUPPORTED = 4

    # User's code didn't provide necessary details for us to successfully trace
    # and export. For example, we use a lot of decorators and ask users to
    # annotate their model.
    MISSING_PROPERTY = 5

    # User is using an API without proper initialization step.
    UNINITIALIZED = 6


class InternalError(Exception):
    """
    Raised when an internal invariance is violated in EXIR stack.
    Should hint users to report a bug to dev and expose the original
    error message.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)


class ExportError(Exception):
    """
    This type of exception is raised for errors that are directly caused by the user
    code. In general, user errors happen during model authoring, tracing, using our public
    facing APIs, and writing graph passes.
    """

    def __init__(self, error_code: ExportErrorType, message: str) -> None:
        # The error category is embedded in the message so callers see it in logs.
        super().__init__(f"[{error_code}]: {message}")


def internal_assert(pred: bool, assert_msg: str) -> None:
    """
    This is exir's custom assert method. It internally just throws InternalError.
    Note that the sole purpose is to throw our own error while maintaining similar syntax
    as python assert.
    """
    if pred:
        return
    raise InternalError(assert_msg)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ import traceback
3
+ import typing
4
+ from contextlib import nullcontext
5
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
6
+
7
+ import torch
8
+ from functorch.experimental.control_flow import _unstack_pytree
9
+ from torch import fx
10
+ from torch._dispatch.python import enable_python_dispatcher
11
+ from torch._export.pass_infra.node_metadata import NodeMetadata
12
+ from torch._export.pass_infra.proxy_value import ProxyValue
13
+ from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
14
+ from torch._subclasses.fake_tensor import FakeTensorMode
15
+ from torch.fx import traceback as fx_traceback
16
+ from torch.fx.experimental.proxy_tensor import PythonKeyTracer
17
+ from torch.fx.graph import CodeGen
18
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
19
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
20
+ from torch.utils import _pytree as pytree
21
+
22
+
23
__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]


# Type aliases: values flowing through the pass interpreter are untyped at this
# level, so arguments and results are treated as Any.
Argument = Any
Value = Any
Fn = Callable[..., Any]
# A pass is any callable mapping a GraphModule to an optional PassResult.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


# Symbolic-math helpers from torch; ExportInterpreter.call_function routes these
# through callback.call_sym rather than call_operator.
_TORCH_SYM_OPS: Set[Callable] = {
    torch.sym_int,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}
40
+
41
+
42
class ExportPassBaseError(RuntimeError):
    """Raised when the export pass infrastructure is misused or hits an unsupported case."""
44
+
45
+
46
class _ExportPassBaseDeprecatedDoNotUse(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.

    The pass re-executes the input graph node-by-node (ExportInterpreter) while
    re-recording an equivalent, possibly transformed graph (ExportTracer).
    Subclasses override the ``call_*`` hooks to rewrite individual operations.
    """

    @staticmethod
    def _create_dummy_node_metadata():
        # Fallback metadata for nodes a pass synthesizes itself: capture the
        # current Python stack so the new node has *some* stack_trace.
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})

    class ExportTracer(PythonKeyTracer):
        """Tracer that records the transformed graph emitted by the callback pass."""

        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            # Populated by the owning pass before tracing begins (see call()).
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            # Maps submodule object -> attribute name registered on self.root.
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:
            # Nodes are recorded via create_proxy/create_arg only; whole-module
            # tracing is deliberately unsupported.
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            """Convert a value into a graph node, registering submodules and
            unwrapping constant-backed FakeTensors along the way."""
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                # Only FakeTensors backed by a real constant can become graph
                # attributes; a pure fake has no storage to embed.
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(
            self, node: torch.fx.Node, value: Argument,
        ) -> None:
            """Populate node.meta["val"] (fake tensors / sym values) and
            node.meta["tensor_meta"] (fallback metadata) from *value*."""
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported \
                            right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

    class ExportInterpreter(fx.Interpreter):
        """Walks the input graph and dispatches each node to the callback's hooks."""

        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            # Tracks the node currently being executed so hooks can read its meta.
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(
            self,
            target: str,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            """Route the call to the matching callback hook by target kind."""
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {"_operator", "math"}:
                # Plain Python operators / math functions on symbolic values.
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            # Export IR does not contain call_module nodes.
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> None:
            # Export IR does not contain call_method nodes.
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            # Remember the current node (for meta propagation) and a debug
            # string on the callback before delegating.
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        # Inner interpreter executes ops on concrete/fake data; tracer records
        # the rewritten graph. Both are swapped per-submodule in call_submodule.
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        # Guard checked in call(): subclasses must chain to this __init__.
        self._initialized = True
        self.node_debug_str: typing.Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Execute *target* on the unwrapped data via the inner interpreter,
        record a corresponding proxy node in the new graph, and return both
        bundled as a ProxyValue."""
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            # Name the node after the op packet for readable graphs.
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
        res_proxy.node.meta.update(meta.data)
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        """Reconstruct example inputs for *graph_module* from its metadata,
        one entry per placeholder node."""
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                # Rebuild a FakeTensor from the recorded metadata.
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                # Unused placeholder: its value never matters downstream.
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        # Hook invoked when a get_attr node is recorded; default is a no-op.
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        """Record a placeholder in the new graph for input *arg*."""
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        self.tracer.set_metadata(arg_proxy.node, arg)
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # Default: re-emit the op unchanged. Subclasses override to transform.
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # Symbolic-math calls carry no kwargs.
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Recursively run the pass over both branches, then re-emit cond."""
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Recursively run the pass over the mapped body (using one unstacked
        slice of the mapped args as example inputs), then re-emit map_impl."""
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        """Run the pass over one graph module, returning the rewritten module.

        Swaps in a fresh tracer/interpreter for the duration and restores the
        previous ones afterwards, so nested submodules (cond/map bodies) work.
        """
        prev_tracer, self.tracer = self.tracer, self.ExportTracer(
            self, graph_module.graph._codegen
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
        with fx_traceback.preserve_node_meta():
            interpreter.run(*inputs_data)

        new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

        self.tracer = prev_tracer
        self.interpreter = prev_interpreter
        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        """Entry point: set up fake-tensor/dispatcher modes and run the pass."""
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        # All fake inputs must share a single FakeTensorMode.
        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor mode detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            # No fake inputs: fabricate a mode for metadata propagation, but do
            # not enter it (nullcontext) while interpreting.
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc ADDED
Binary file (2.86 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyre-strict
2
+ from typing import Union
3
+
4
+ import torch
5
+
6
+
7
class ProxyValue:
    """Bundles a concrete value with the fx Proxy (or Node) that produced it."""

    # pyre-ignore
    def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
        # pyre-ignore
        self.data = data
        self.proxy_or_node = proxy

    @property
    def node(self) -> torch.fx.Node:
        """The underlying fx Node, unwrapping a Proxy if necessary."""
        underlying = self.proxy_or_node
        if isinstance(underlying, torch.fx.Node):
            return underlying
        assert isinstance(underlying, torch.fx.Proxy)
        return underlying.node

    @property
    def proxy(self) -> torch.fx.Proxy:
        """The attached Proxy; raises if only a bare Node is held."""
        if isinstance(self.proxy_or_node, torch.fx.Proxy):
            return self.proxy_or_node
        raise RuntimeError(
            f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}"
        )

    def to_tensor(self) -> torch.Tensor:
        """Return the data, asserting it is a tensor."""
        assert isinstance(self.data, torch.Tensor)
        return self.data

    def is_tensor(self) -> bool:
        """Whether the wrapped data is a tensor."""
        return isinstance(self.data, torch.Tensor)

    # pyre-ignore
    def __iter__(self):
        # Delegate iteration to the wrapped data.
        yield from self.data

    def __bool__(self) -> bool:
        # Truthiness mirrors the wrapped data.
        return bool(self.data)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (220 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from dataclasses import fields
3
+ from typing import Hashable, Set
4
+
5
+
6
class _UnionTag(str):
    """A str subclass naming which field of a union dataclass is active.

    Equality against an arbitrary string is validated: comparing with a name
    that is not a field of the owning class asserts, catching typos early.
    """

    # The dataclass this tag belongs to; set exactly once in create().
    _cls: Hashable

    @staticmethod
    def create(t, cls):
        tag = _UnionTag(t)
        # _cls is only an annotation on the class, so a fresh tag never has it.
        assert not hasattr(tag, "_cls")
        tag._cls = cls
        return tag

    def __eq__(self, cmp) -> bool:
        assert isinstance(cmp, str)
        other = str(cmp)
        # Comparing against a non-field name is a bug in the caller, not False.
        assert other in _get_field_names(
            self._cls
        ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
        return str(self) == other

    def __hash__(self):
        # Hash like a plain string so tags interoperate with str dict keys.
        return hash(str(self))
26
+
27
+
28
+ @functools.lru_cache(maxsize=None)
29
+ def _get_field_names(cls) -> Set[str]:
30
+ return {f.name for f in fields(cls)}
31
+
32
+
33
class _Union:
    """Base class for tagged-union dataclasses: exactly one field is set,
    all others are None, and reading an unset field raises AttributeError."""

    # Tag naming which single field is populated; set by create().
    _type: _UnionTag

    @classmethod
    def create(cls, **kwargs):
        # Build with every field None except the single provided one.
        assert len(kwargs) == 1
        obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs})  # type: ignore[arg-type]
        obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls)
        return obj

    def __post_init__(self):
        # Field names must not collide with the union machinery defined below.
        assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self))  # type: ignore[arg-type, misc]

    @property
    def type(self) -> str:
        """Name of the active field; raises if not built via create()."""
        try:
            return self._type
        except AttributeError as e:
            raise RuntimeError(
                f"Please use {type(self).__name__}.create to instantiate the union type."
            ) from e

    @property
    def value(self):
        # Payload stored under the active tag.
        return getattr(self, self.type)

    def __getattribute__(self, name):
        attr = super().__getattribute__(name)
        # Reading an inactive union field (value None, name is a field, and it
        # is not the active tag) is an error rather than silently None.
        if attr is None and name in _get_field_names(type(self)) and name != self.type:  # type: ignore[arg-type]
            raise AttributeError(f"Field {name} is not set.")
        return attr

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f"{type(self).__name__}({self.type}={getattr(self, self.type)})"
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import torch.fx
4
+ import torch.utils._pytree as pytree
5
+
6
# Public API of torch._inductor.
# NOTE(review): `aot_compile` is defined below but not listed here — confirm intentional.
__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]
7
+
8
+
9
def compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
):
    """
    Compile a given FX graph with TorchInductor. This allows compiling
    FX graphs captured without using TorchDynamo.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options. See `torch._inductor.config`.

    Returns:
        Callable with same behavior as gm but faster.
    """
    # Imported lazily so importing torch._inductor stays cheap.
    from .compile_fx import compile_fx

    compiled_fn = compile_fx(gm, example_inputs, config_patches=options)
    return compiled_fn
29
+
30
+
31
def aot_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Ahead-of-time compile a given FX graph with TorchInductor into a shared library.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options. See `torch._inductor.config`.

    Returns:
        Path to the generated shared library
    """
    from .compile_fx import compile_fx_aot

    # We will serialize the pytree info into the .so as constant strings
    in_spec = None
    out_spec = None
    if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen):
        # Strip the pytree codegen (recompiling with plain CodeGen) but keep
        # its in/out specs for serialization.
        codegen = gm.graph._codegen
        gm.graph._codegen = torch.fx.graph.CodeGen()
        gm.recompile()

        if codegen.pytree_info.in_spec is not None:
            in_spec = codegen.pytree_info.in_spec
        if codegen.pytree_info.out_spec is not None:
            out_spec = codegen.pytree_info.out_spec

    else:
        if hasattr(gm, "_in_spec"):
            in_spec = gm._in_spec
        if hasattr(gm, "_out_spec"):
            out_spec = gm._out_spec

    serialized_in_spec = "" if in_spec is None else pytree.treespec_dumps(in_spec)
    serialized_out_spec = "" if out_spec is None else pytree.treespec_dumps(out_spec)

    # The serialized specs always win over any caller-provided values.
    spec_patches = {
        "aot_inductor.serialized_in_spec": serialized_in_spec,
        "aot_inductor.serialized_out_spec": serialized_out_spec,
    }
    options = spec_patches if options is None else {**options, **spec_patches}

    return compile_fx_aot(
        gm,
        example_inputs,
        config_patches=options,
    )
91
+
92
+
93
def list_mode_options(
    mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> Dict[str, Any]:
    r"""Returns a dictionary describing the optimizations that each of the available
    modes passed to `torch.compile()` performs.

    Args:
        mode (str, optional): The mode to return the optimizations for.
            If None, returns optimizations for all modes
        dynamic (bool, optional): Whether dynamic shape is enabled.

    Example::
        >>> torch._inductor.list_mode_options()
    """
    mode_options: Dict[str, Dict[str, bool]] = {
        "default": {},
        # enable cudagraphs
        "reduce-overhead": {"triton.cudagraphs": True},
        # enable max-autotune
        "max-autotune-no-cudagraphs": {"max_autotune": True},
        # enable max-autotune and cudagraphs
        "max-autotune": {
            "max_autotune": True,
            "triton.cudagraphs": True,
        },
    }
    if mode:
        return mode_options[mode]
    return mode_options  # type: ignore[return-value]
126
+
127
+
128
+ def list_options() -> List[str]:
129
+ r"""Returns a dictionary describing the optimizations and debug configurations
130
+ that are available to `torch.compile()`.
131
+
132
+ The options are documented in `torch._inductor.config`.
133
+
134
+ Example::
135
+
136
+ >>> torch._inductor.list_options()
137
+ """
138
+
139
+ from torch._inductor import config
140
+
141
+ current_config: Dict[str, Any] = config.shallow_copy_dict()
142
+
143
+ return list(current_config.keys())
144
+
145
+
146
+ def cudagraph_mark_step_begin():
147
+ "Indicates that a new iteration of inference or training is about to begin."
148
+ from .cudagraph_trees import mark_step_begin
149
+
150
+ mark_step_begin()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (5.23 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc ADDED
Binary file (68.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-311.pyc ADDED
Binary file (12.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc ADDED
Binary file (38.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc ADDED
Binary file (33 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc ADDED
Binary file (16.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc ADDED
Binary file (1.33 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc ADDED
Binary file (4.98 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (16.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc ADDED
Binary file (730 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc ADDED
Binary file (64.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc ADDED
Binary file (39.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc ADDED
Binary file (3.26 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc ADDED
Binary file (1.99 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc ADDED
Binary file (21.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <algorithm>
4
+ #include <atomic>
5
+ #include <cmath>
6
+ #include <cstdlib>
7
+ #include <limits>
8
+ #include <omp.h>
9
+
10
+ #include <ATen/NumericUtils.h>
11
+ #include <ATen/core/PhiloxRNGEngine.h>
12
+ #include <ATen/native/Math.h>
13
+
14
+ #include <c10/util/Float8_e4m3fn.h>
15
+ #include <c10/util/Float8_e5m2.h>
16
+ #include <c10/util/BFloat16.h>
17
+ #include <c10/util/BFloat16-math.h>
18
+ #include <c10/util/generic_math.h>
19
+ #include <c10/util/Half.h>
20
+ #include <c10/util/TypeCast.h>
21
+
22
+ #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
23
+ #define INDUCTOR_USE_VECTOR_TYPES() 1
24
+ #else
25
+ #define INDUCTOR_USE_VECTOR_TYPES() 0
26
+ #endif
27
+
28
+ #if INDUCTOR_USE_VECTOR_TYPES()
29
+ #include <ATen/cpu/vec/functional.h>
30
+ #include <ATen/cpu/vec/vec.h>
31
+ #include <ATen/cpu/vec/vec_n.h>
32
+ #endif
33
+
34
+ typedef at::Half half;
35
+ typedef at::BFloat16 bfloat16;
36
+
37
+ typedef at::Float8_e4m3fn float8_e4m3fn;
38
+ typedef at::Float8_e5m2 float8_e5m2;
39
+
40
+ template <typename T>
41
+ struct Welford {
42
+ T mean = T(0);
43
+ T m2 = T(0);
44
+ T weight = T(0);
45
+ };
46
+
47
+
48
+ template <typename T>
49
+ struct IsVecType: std::false_type {};
50
+
51
+ #if INDUCTOR_USE_VECTOR_TYPES()
52
+ template <typename T>
53
+ struct IsVecType<at::vec::Vectorized<T>>: std::true_type {};
54
+ #endif
55
+
56
+ template <typename T>
57
+ Welford<T> welford_combine(const Welford<T> &a, const Welford<T> &b) {
58
+ if constexpr (!IsVecType<T>::value) {
59
+ if (a.weight == 0) {
60
+ return b;
61
+ }
62
+ if (b.weight == 0) {
63
+ return a;
64
+ }
65
+ }
66
+ auto delta = b.mean - a.mean;
67
+ auto new_weight = a.weight + b.weight;
68
+ auto wb_over_w = b.weight / new_weight;
69
+ if constexpr (IsVecType<T>::value) {
70
+ // Guard against division by zero
71
+ wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0));
72
+ }
73
+ auto result = Welford<T>{
74
+ a.mean + delta * wb_over_w,
75
+ a.m2 + b.m2 + delta * delta * a.weight * wb_over_w,
76
+ new_weight
77
+ };
78
+ return result;
79
+ }
80
+
81
+ template <typename T>
82
+ Welford<T> welford_combine(const Welford<T> &acc, T data) {
83
+ // Add a single data point
84
+ auto delta = data - acc.mean;
85
+ auto new_weight = acc.weight + T(1);
86
+ auto new_mean = acc.mean + delta / new_weight;
87
+ auto new_delta = data - new_mean;
88
+ auto result = Welford<T>{
89
+ new_mean,
90
+ acc.m2 + delta * new_delta,
91
+ new_weight
92
+ };
93
+ return result;
94
+ }
95
+
96
+ // Refer to https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/
97
+ // aten/src/ATen/native/SharedReduceOps.h#L419-L445
98
+ template <typename scalar_t>
99
+ inline bool greater_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
100
+ // If (a == b), then choose the one with lower idx, else max(a, b)
101
+ if (at::_isnan(a)) {
102
+ if (at::_isnan(b)) {
103
+ return idx_a < idx_b;
104
+ }
105
+ return true;
106
+ }
107
+ return (a == b) ? idx_a < idx_b : (a > b);
108
+ }
109
+
110
+ template <typename scalar_t>
111
+ inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
112
+ // If (a == b), then choose the one with lower idx, else min(a, b)
113
+ if (at::_isnan(a)) {
114
+ if (at::_isnan(b)) {
115
+ return idx_a < idx_b;
116
+ }
117
+ return true;
118
+ }
119
+ return (a == b) ? idx_a < idx_b : (a < b);
120
+ }
121
+
122
+ #if INDUCTOR_USE_VECTOR_TYPES()
123
+ template <typename scalar_t>
124
+ inline at::vec::Vectorized<scalar_t> vec_shuffle_down(at::vec::Vectorized<scalar_t> x, size_t n) {
125
+ using Vec = at::vec::Vectorized<scalar_t>;
126
+ alignas(alignof(Vec)) scalar_t array[Vec::size()];
127
+ x.store(array);
128
+ for (size_t i = 0; i + n < Vec::size(); i += 2 * n) {
129
+ array[i] = array[i + n];
130
+ }
131
+ return Vec::loadu(array);
132
+ }
133
+
134
+ #ifdef CPU_CAPABILITY_AVX2
135
+ inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x, size_t n) {
136
+ using vec_t = at::vec::Vectorized<float>;
137
+ #define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
138
+ switch (n) {
139
+ case 1:
140
+ return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
141
+ case 2:
142
+ return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
143
+ case 4:
144
+ return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1)));
145
+ }
146
+ TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
147
+ }
148
+ #endif
149
+
150
+ template <typename scalar_t>
151
+ Welford<scalar_t> welford_vec_reduce_all(Welford<at::vec::Vectorized<scalar_t>> acc) {
152
+ using Vec = at::vec::Vectorized<scalar_t>;
153
+ for (size_t n = 1; n < Vec::size(); n *= 2) {
154
+ auto shuffled = Welford<Vec>{
155
+ vec_shuffle_down(acc.mean, n),
156
+ vec_shuffle_down(acc.m2, n),
157
+ vec_shuffle_down(acc.weight, n)
158
+ };
159
+ acc = welford_combine(acc, shuffled);
160
+ }
161
+
162
+ Welford<scalar_t> result;
163
+ alignas(alignof(Vec)) scalar_t array[Vec::size()];
164
+ acc.mean.store(array);
165
+ result.mean = array[0];
166
+
167
+ acc.m2.store(array);
168
+ result.m2 = array[0];
169
+
170
+ acc.weight.store(array);
171
+ result.weight = array[0];
172
+
173
+ return result;
174
+ }
175
+ #endif
176
+
177
+
178
+ template <typename T, typename U> inline typename std::common_type<T, U>::type mod(T a, U b) { return a % b; }
179
+ template <> inline float mod(float a, float b) { return std::fmod(a, b); }
180
+ template <> inline double mod(double a, double b) { return std::fmod(a, b); }
181
+
182
+ template <typename scalar_t>
183
+ inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
184
+ if (at::_isnan(a)) {
185
+ return a;
186
+ }
187
+ return a > b ? a : b;
188
+ }
189
+
190
+ template <typename scalar_t>
191
+ inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
192
+ if (at::_isnan(a)) {
193
+ return a;
194
+ }
195
+ return a < b ? a : b;
196
+ }
197
+
198
+ constexpr float uint32_to_uniform_float(uint32_t value) {
199
+ // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
200
+ constexpr float scale = 4.6566127342e-10;
201
+ return static_cast<float>(value & 0x7FFFFFFF) * scale;
202
+ }
203
+
204
+ float normalized_rand_cpu(uint32_t seed, uint32_t offset) {
205
+ return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)());
206
+ }
207
+
208
+ float randn_cpu(uint32_t seed, uint32_t offset) {
209
+ at::Philox4_32 engine(seed, 0, offset);
210
+ return engine.randn(10);
211
+ }
212
+
213
+ int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_t high) {
214
+ auto gen = at::Philox4_32(seed, 0, offset);
215
+ uint64_t r0 = gen();
216
+ uint64_t r1 = gen();
217
+ uint64_t result = r0 | (r1 << 32);
218
+ return static_cast<int64_t>(result % (high - low)) + low;
219
+ }
220
+
221
+ template <typename T> struct AsIntegerType { typedef T type; };
222
+ template <> struct AsIntegerType<float> { typedef uint32_t type; };
223
+ template <> struct AsIntegerType<double> { typedef uint64_t type; };
224
+ template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };
225
+
226
+ template <typename T>
227
+ typename std::enable_if<!std::is_reduced_floating_point<T>::value, T>::type
228
+ inline fetch_value(volatile T *addr) {
229
+ return *addr;
230
+ }
231
+
232
+ template <typename T>
233
+ typename std::enable_if<std::is_reduced_floating_point<T>::value, T>::type
234
+ inline fetch_value(volatile T *addr) {
235
+ return T(addr->x, T::from_bits());
236
+ }
237
+
238
+ template <typename T>
239
+ typename std::enable_if<!std::is_integral<T>::value>::type
240
+ atomic_add(volatile T *addr, T offset) {
241
+ typedef typename AsIntegerType<T>::type alt_type;
242
+
243
+ static_assert(sizeof(std::atomic<alt_type>) == sizeof(T),
244
+ "std::atomic issue");
245
+
246
+ alt_type expected;
247
+
248
+ alt_type desired;
249
+
250
+ std::atomic<alt_type> *atomic_addr = (std::atomic<alt_type> *)addr;
251
+ do {
252
+ T val = fetch_value(addr);
253
+ reinterpret_cast<T *>(&expected)[0] = val;
254
+ reinterpret_cast<T *>(&desired)[0] = val + offset;
255
+ } while (!atomic_addr->compare_exchange_weak(expected, desired,
256
+ std::memory_order_relaxed));
257
+ }
258
+
259
+ // Since C++20 float is supported by fetch_add, but the performance may not
260
+ // better than compare_exchange_weak, which can be checked by microbenchmark
261
+ // inductor_cpu_atomic.py
262
+ template <typename T>
263
+ typename std::enable_if<std::is_integral<T>::value>::type
264
+ atomic_add(volatile T *addr, T offset) {
265
+ static_assert(sizeof(std::atomic<T>) == sizeof(T),
266
+ "std::atomic issue");
267
+ std::atomic<T> *atomic_addr = (std::atomic<T> *)addr;
268
+ atomic_addr->fetch_add(offset, std::memory_order_relaxed);
269
+ }
270
+
271
+ // This function is used to convert bool or uint8 to float mask for
272
+ // vectorization. The caller needs to make sure the src represents TRUE/FALSE
273
+ // correctly.
274
+ template <typename T>
275
+ inline float flag_to_float_scalar(T src) {
276
+ float ret;
277
+ *(uint32_t*)(&ret) = src ? 0xFFFFFFFF : 0;
278
+ return ret;
279
+ }
280
+
281
+ #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
282
+
283
+ inline at::vec::Vectorized<float> masked_load(const float* src, at::vec::Vectorized<float> mask) {
284
+ # if defined(CPU_CAPABILITY_AVX512)
285
+ at::vec::Vectorized<float> zero_vec(0);
286
+ auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
287
+ auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ);
288
+ return _mm512_mask_loadu_ps(zero_vec, mmask, src);
289
+ # elif defined(CPU_CAPABILITY_AVX2)
290
+ auto all_ones = _mm256_set1_epi32(0xFFFFFFFF);
291
+ auto mmask = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones);
292
+ return _mm256_maskload_ps(src, mmask);
293
+ # elif defined(CPU_CAPABILITY_ZVECTOR)
294
+ auto result = at::vec::Vectorized<float>::loadu(src);
295
+ return (result & mask);
296
+ # else
297
+ # error Unsupported vectorization CPU capability
298
+ # endif
299
+ }
300
+
301
+ template <typename T>
302
+ typename std::enable_if<std::is_same<T, bfloat16>::value || std::is_same<T, half>::value, at::vec::Vectorized<T>>::type
303
+ inline masked_load(const T* src, at::vec::Vectorized<float> mask) {
304
+ # if defined(CPU_CAPABILITY_AVX512)
305
+ auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
306
+ auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ);
307
+ auto zero = _mm256_set1_epi16(0);
308
+ auto temp = _mm256_mask_loadu_epi16(zero, mmask, src);
309
+ return _mm512_inserti32x8(_mm512_castsi256_si512(temp), zero, 1);
310
+ # elif defined(CPU_CAPABILITY_AVX2)
311
+ auto all_ones = _mm256_set1_epi32(0xFFFFFFFF);
312
+ auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones);
313
+ __at_align__ uint32_t mmask[8];
314
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec);
315
+ __at_align__ uint16_t result[16];
316
+ for (auto i = 0; i < 8; i++) {
317
+ result[i] = mmask[i] == 0xFFFFFFFF ? src[i].x: uint16_t(0);
318
+ }
319
+ return at::vec::Vectorized<T>::loadu(result);
320
+ # elif defined(CPU_CAPABILITY_ZVECTOR)
321
+ auto result = at::vec::Vectorized<T>::loadu(src, 8);
322
+ uint32_t maskdata[8] = { 0 };
323
+ uint16_t maskdata_dest[16] = { 0 };
324
+ mask.store(maskdata);
325
+ for (auto i = 0; i < 8; i++) {
326
+ maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFFFF: 0;
327
+ }
328
+ auto maskvector = at::vec::Vectorized<T>::loadu(maskdata_dest);
329
+ return (result & maskvector);
330
+ # else
331
+ # error Unsupported vectorization CPU capability
332
+ # endif
333
+ }
334
+
335
+ template <typename T>
336
+ typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, at::vec::Vectorized<T>>::type
337
+ inline masked_load(const T* src, at::vec::Vectorized<float> mask) {
338
+ # if defined(CPU_CAPABILITY_AVX512)
339
+ auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
340
+ auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ);
341
+ auto zero = _mm_set1_epi8(0);
342
+ auto temp = _mm_mask_loadu_epi8(zero, mmask, src);
343
+ return _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0);
344
+ # elif defined(CPU_CAPABILITY_AVX2)
345
+ auto all_ones = _mm256_set1_epi32(0xFFFFFFFF);
346
+ auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones);
347
+ __at_align__ uint32_t mmask[8];
348
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec);
349
+ __at_align__ T result[32];
350
+ for (auto i = 0; i < 8; i++) {
351
+ result[i] = mmask[i] == 0xFFFFFFFF ? src[i]: T(0);
352
+ }
353
+ return at::vec::Vectorized<T>::loadu(result);
354
+ # elif defined(CPU_CAPABILITY_ZVECTOR)
355
+ auto result = at::vec::Vectorized<T>::loadu(src, 8);
356
+ uint32_t maskdata[8];
357
+ T maskdata_dest[32] = { 0 };
358
+ mask.store(maskdata);
359
+ for (auto i = 0; i < 8; i++) {
360
+ maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFF: 0;
361
+ }
362
+ auto maskvector = at::vec::Vectorized<T>::loadu(maskdata_dest);
363
+ return (result & maskvector);
364
+ # else
365
+ # error Unsupported vectorization CPU capability
366
+ # endif
367
+ }
368
+
369
+ template <typename T>
370
+ inline at::vec::Vectorized<float> flag_to_float_vec(const T* src) {
371
+ __at_align__ float dst_tmp[at::vec::Vectorized<float>::size()];
372
+ #pragma unroll
373
+ for (int64_t i = 0; i < at::vec::Vectorized<float>::size(); i++) {
374
+ dst_tmp[i] = flag_to_float_scalar(src[i]);
375
+ }
376
+ return at::vec::Vectorized<float>::loadu(dst_tmp);
377
+ }
378
+
379
+ template <typename scalar_t>
380
+ inline at::vec::Vectorized<float> cvt_lowp_fp_to_fp32(
381
+ at::vec::Vectorized<scalar_t> src) {
382
+ at::vec::Vectorized<float> res_vec1(0);
383
+ at::vec::Vectorized<float> res_vec2(0);
384
+ std::tie(res_vec1, res_vec2) = at::vec::convert_to_float<scalar_t>(src);
385
+ return res_vec1;
386
+ }
387
+
388
+ template <typename scalar_t>
389
+ inline at::vec::Vectorized<scalar_t> cvt_fp32_to_lowp_fp(
390
+ at::vec::Vectorized<float> src) {
391
+ return at::vec::convert_from_float<scalar_t>(src, src);
392
+ }
393
+
394
+ inline at::vec::Vectorized<float> mask_convert_to_float(at::vec::Vectorized<float> src) {
395
+ auto zeros = at::vec::Vectorized<float>(0);
396
+ auto ones = at::vec::Vectorized<float>(1);
397
+ return at::vec::Vectorized<float>::blendv(zeros, ones, src);
398
+ }
399
+
400
+ template <typename scalar_t>
401
+ inline
402
+ typename std::enable_if<std::is_same<scalar_t, bfloat16>::value || std::is_same<scalar_t, half>::value, at::vec::Vectorized<scalar_t>>::type
403
+ mask_convert_to_lowp(at::vec::Vectorized<float> src) {
404
+ auto fp_vec = mask_convert_to_float(src);
405
+ return cvt_fp32_to_lowp_fp<scalar_t>(fp_vec);
406
+ }
407
+
408
+ template <typename SRC>
409
+ inline at::vec::Vectorized<float> vec_convert_to_mask(at::vec::Vectorized<SRC> src) {
410
+ assert(
411
+ at::vec::Vectorized<float>::size() == at::vec::Vectorized<SRC>::size());
412
+ at::vec::Vectorized<float> res_vec(0);
413
+ __at_align__ float dst_tmp[at::vec::Vectorized<float>::size()];
414
+ __at_align__ SRC src_tmp[at::vec::Vectorized<SRC>::size()];
415
+ src.store(src_tmp);
416
+
417
+ #pragma unroll
418
+ for (int i = 0; i < at::vec::Vectorized<float>::size(); i++) {
419
+ *(uint32_t*)(dst_tmp + i) = src_tmp[i] ? 0xFFFFFFFF : 0;
420
+ }
421
+
422
+ return res_vec.loadu(dst_tmp);
423
+ }
424
+
425
+ template <typename SRC>
426
+ inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<SRC> src) {
427
+ return vec_convert_to_mask(src);
428
+ }
429
+
430
+ #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
431
+ template <>
432
+ inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<int> src) {
433
+ #if defined(CPU_CAPABILITY_AVX2)
434
+ return at::vec::Vectorized<float>(_mm256_castsi256_ps(src));
435
+ #else
436
+ return at::vec::Vectorized<float>(_mm512_castsi512_ps(src));
437
+ #endif
438
+ }
439
+ #endif
440
+
441
+ template <>
442
+ inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<float> src) {
443
+ return src;
444
+ }
445
+
446
+ inline at::vec::Vectorized<float> to_float_mask(int src) {
447
+ union {
448
+ float fmask;
449
+ uint32_t imask;
450
+ } mask;
451
+ mask.imask = src ? 0xFFFFFFFF : 0;
452
+ return at::vec::Vectorized<float>(mask.fmask);
453
+ }
454
+
455
+ inline bool all_zero(at::vec::Vectorized<float> src) {
456
+ # if defined(CPU_CAPABILITY_AVX512)
457
+ auto src_int = _mm512_castps_si512(src);
458
+ __mmask16 mask = _mm512_test_epi32_mask(src_int, src_int);
459
+ return mask == 0;
460
+ # elif defined(CPU_CAPABILITY_AVX2)
461
+ return _mm256_testz_ps(src, src);
462
+ # else
463
+ __at_align__ int mask[at::vec::Vectorized<float>::size()];
464
+ src.store(mask);
465
+ for (int i = 0; i < at::vec::Vectorized<float>::size(); i++) {
466
+ if (mask[i] != 0) {
467
+ return false;
468
+ }
469
+ }
470
+ return true;
471
+ # endif
472
+ }
473
+
474
+ inline bool vector_lane_mask_check(at::vec::Vectorized<float> src, int lane) {
475
+ # if defined(CPU_CAPABILITY_AVX512)
476
+ return _mm512_movepi32_mask(_mm512_castps_si512(src)) & (1 << lane);
477
+ # elif defined(CPU_CAPABILITY_AVX2)
478
+ return _mm256_movemask_ps(src) & (1 << lane);
479
+ # else
480
+ __at_align__ int mask[at::vec::Vectorized<float>::size()];
481
+ src.store(mask);
482
+ return mask[lane] != 0;
483
+ # endif
484
+ }
485
+
486
+ inline at::vec::Vectorized<float> cvt_int64_to_fp32(at::vec::VectorizedN<int64_t,2> src) {
487
+ # if defined(CPU_CAPABILITY_AVX512)
488
+ auto low = _mm512_cvtepi64_ps(src[0]);
489
+ auto high = _mm512_cvtepi64_ps(src[1]);
490
+ return _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1);
491
+ # elif defined(CPU_CAPABILITY_AVX2)
492
+ auto low_double = at::vec::convert_to_fp_of_same_size<double>(src[0]);
493
+ auto low = _mm256_cvtpd_ps(low_double);
494
+ auto high_double = at::vec::convert_to_fp_of_same_size<double>(src[1]);
495
+ auto high = _mm256_cvtpd_ps(high_double);
496
+ return _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1);
497
+ # else
498
+ constexpr int float_vec_size = at::vec::Vectorized<float>::size();
499
+ constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
500
+ __at_align__ float result[float_vec_size];
501
+ __at_align__ int64_t src_buf[int64_vec_size];
502
+ for (int i = 0; i < 2; i++) {
503
+ src[i].store(src_buf + i * int64_vec_size);
504
+ for (int j = 0; j < int64_vec_size; j++) {
505
+ result[i * int64_vec_size + j] = static_cast<float>(src_buf[i * int64_vec_size + j]);
506
+ }
507
+ }
508
+ return at::vec::Vectorized<float>::loadu(result);
509
+ # endif
510
+ }
511
+
512
+ inline at::vec::VectorizedN<int64_t,2> cvt_fp32_to_int64(at::vec::Vectorized<float> src) {
513
+ at::vec::VectorizedN<int64_t,2> result;
514
+ # if defined(CPU_CAPABILITY_AVX512)
515
+ result[0] = _mm512_cvt_roundps_epi64(_mm512_castps512_ps256(src), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC);
516
+ result[1] = _mm512_cvt_roundps_epi64(_mm512_extractf32x8_ps(src, 1), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC);
517
+ # elif defined(CPU_CAPABILITY_AVX2)
518
+ auto int32_vec = at::vec::convert_to_int_of_same_size(src);
519
+ result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(int32_vec));
520
+ result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(int32_vec, 1));
521
+ # else
522
+ constexpr int float_vec_size = at::vec::Vectorized<float>::size();
523
+ constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
524
+ __at_align__ float src_buf[float_vec_size];
525
+ __at_align__ int64_t result_buf[int64_vec_size];
526
+ src.store(src_buf);
527
+ for (int i = 0; i < 2; i++) {
528
+ for (int j = 0; j < int64_vec_size; j++) {
529
+ result_buf[j] = static_cast<int64_t>(src_buf[i * int64_vec_size + j]);
530
+ }
531
+ result[i] = at::vec::Vectorized<int64_t>::loadu(result_buf);
532
+ }
533
+ # endif
534
+ return result;
535
+ }
536
+
537
+ inline at::vec::Vectorized<int32_t> cvt_int64_to_int32(at::vec::VectorizedN<int64_t,2> src) {
538
+ # if defined(CPU_CAPABILITY_AVX512)
539
+ auto low = _mm512_cvtepi64_epi32(src[0]);
540
+ auto high = _mm512_cvtepi64_epi32(src[1]);
541
+ return _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1);
542
+ # elif defined(CPU_CAPABILITY_AVX2)
543
+ auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0));
544
+ auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0));
545
+ auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0));
546
+ auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0));
547
+ return _mm256_blend_epi32(low_perm, high_perm, 0xF0);
548
+ # else
549
+ constexpr int int32_vec_size = at::vec::Vectorized<int32_t>::size();
550
+ constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
551
+ __at_align__ int32_t result[int32_vec_size];
552
+ __at_align__ int64_t src_buf[int64_vec_size];
553
+ for (int i = 0; i < 2; i++) {
554
+ src[i].store(src_buf + i * int64_vec_size);
555
+ for (int j = 0; j < int64_vec_size; j++) {
556
+ result[i * int64_vec_size + j] = static_cast<int32_t>(src_buf[i * int64_vec_size + j]);
557
+ }
558
+ }
559
+ return at::vec::Vectorized<int32_t>::loadu(result);
560
+ # endif
561
+ }
562
+
563
+ inline at::vec::VectorizedN<int64_t,2> cvt_int32_to_int64(at::vec::Vectorized<int32_t> src) {
564
+ at::vec::VectorizedN<int64_t,2> result;
565
+ # if defined(CPU_CAPABILITY_AVX512)
566
+ result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src));
567
+ result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src, 1));
568
+ # elif defined(CPU_CAPABILITY_AVX2)
569
+ result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src));
570
+ result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src, 1));
571
+ #else
572
+ constexpr int int32_vec_size = at::vec::Vectorized<int32_t>::size();
573
+ constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
574
+ __at_align__ int32_t src_buf[int32_vec_size];
575
+ __at_align__ int64_t result_buf[int64_vec_size];
576
+ src.store(src_buf);
577
+ for (int i = 0; i < 2; i++) {
578
+ for (int j = 0; j < int64_vec_size; j++) {
579
+ result_buf[j] = static_cast<int64_t>(src_buf[i * int64_vec_size + j]);
580
+ }
581
+ result[i] = at::vec::Vectorized<int64_t>::loadu(result_buf);
582
+ }
583
+ # endif
584
+ return result;
585
+ }
586
+
587
+ inline at::vec::VectorizedN<int64_t,2> mask_convert_to_int64(at::vec::Vectorized<float> src) {
588
+ return cvt_fp32_to_int64(mask_convert_to_float(src));
589
+ }
590
+
591
+ inline at::vec::Vectorized<float> to_float_mask(at::vec::VectorizedN<int64_t,2> src) {
592
+ return to_float_mask(cvt_int64_to_int32(src));
593
+ }
594
+
595
+ #endif
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py ADDED
@@ -0,0 +1,1851 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import sys
4
+ from itertools import count
5
+ from typing import List, Optional, Tuple
6
+
7
+ import sympy
8
+ from sympy import Expr
9
+
10
+ import torch
11
+ import torch._ops
12
+ from .. import config, ir
13
+
14
+ from ..codecache import CudaKernelParamCache
15
+ from ..utils import cache_on_self, sympy_product
16
+ from ..virtualized import V
17
+ from .common import IndentedBuffer
18
+ from .wrapper import EnterSubgraphLine, ExitSubgraphLine, pexpr, WrapperCodeGen
19
+
20
+
21
+ class CppWrapperCpu(WrapperCodeGen):
22
+ """
23
+ Generates cpp wrapper for running on CPU and calls cpp kernels
24
+ """
25
+
26
def __init__(self):
    # A subclass may have already chosen a device; only default to CPU
    # when nothing was set.
    if not hasattr(self, "device"):
        self.device = "cpu"
    super().__init__()

    # C++ source-emission conventions used by the base wrapper codegen.
    self.declare = "auto "
    self.declare_maybe_reference = "decltype(auto) "
    self.ending = ";"
    self.open_bracket = "{"
    self.closed_bracket = "}"
    self.comment = "//"
    self.namespace = "at::"
    self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()"
    self.extern_call_ops = set()
    self.size = "sizes()"
    self.stride = "strides()"
    self.cuda = False
    self.supports_intermediate_hooks = False
    self.outputs_need_copy = set()

    # Monotonic counters used to mint unique local-variable names in the
    # generated C++ source.
    self.kernel_callsite_id = count()
    self.int_array_id = count()  # for int array local variable declarations
    self.declared_int_array_vars = set()
    self.tmp_tensor_id = count()  # for tmp tensor local variable declarations
    self.arg_var_id = count()
    self.used_cached_devices = set()
    self.used_cached_dtypes = set()
    self.cached_output_id = count()
    self.scalar_to_tensor_id = count()

    from .cpp import cexpr, CppPrinter

    self.expr_printer = cexpr

    # CppPrinter sometimes calls at::native functions which causes problems
    # in the ABI-compatible mode. Currently we are hitting this problem when
    # codegen'ing grid computation expressions, but we may need to fix other
    # size computation as well.
    class GridExprCppPrinter(CppPrinter):
        def _print_FloorDiv(self, expr):
            numerator, denominator = expr.args
            numerator = self.paren(self.doprint(numerator))
            denominator = self.paren(self.doprint(denominator))
            # Integer division maps directly onto C++ `/`.
            assert expr.is_integer, "Expect integers in GridExprPrinter"
            return f"({numerator}/{denominator})"

    self.grid_expr_printer = GridExprCppPrinter().doprint
71
+
72
def generate_kernel_call(
    self,
    name,
    call_args,
    grid=None,
    device_index=None,
    cuda=True,
    triton=True,
    arg_types=None,
    grid_fn: str = "grid",
    triton_meta=None,
):
    """
    Generates kernel call code.

    cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.

    triton: Defines whether the GPU backend uses Triton for codegen.
            Otherwise it uses the CUDA language for codegen.
            Only valid when cuda == True.
    """
    # The GPU path is handled entirely by the parent wrapper codegen.
    if cuda:
        return super().generate_kernel_call(
            name,
            call_args,
            grid,
            device_index,
            cuda,
            triton,
            arg_types,
            grid_fn,
        )

    if not config.abi_compatible:
        self.writeline(self.wrap_kernel_call(name, call_args))
        return

    # ABI-compatible mode: pointer arguments must be unwrapped into raw
    # data pointers before the kernel call.
    assert arg_types is not None and len(call_args) == len(
        arg_types
    ), "Mismatch call_args and arg_types in generate_kernel_call"
    wrapped_args = []
    for arg_type, arg in zip(arg_types, call_args):
        if "*" in arg_type:
            var_name = f"var_{next(self.arg_var_id)}"
            self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});")
            wrapped_args.append(f"({arg_type})({var_name})")
        else:
            # arg is a scalar
            wrapped_args.append(arg)
    self.writeline(self.wrap_kernel_call(name, wrapped_args))
123
+
124
def write_constant(self, name, hashed):
    # Emit the constant's hash so the code cache maps different constants
    # to different generated files.
    self.header.writeline(f"// {name} {hashed}")
127
+
128
def write_header(self):
    """Emit the C++ (or, in JIT mode, Python-embedded C++) file header."""
    if V.graph.is_const_graph:
        # We do not write header for constant graph, it will be written by main module.
        return

    if V.graph.aot_mode:
        # AOT mode inlines the runtime interface/implementation sources.
        for header_cpp_file in ("interface.cpp", "implementation.cpp"):
            with open(
                os.path.join(
                    os.path.dirname(__file__), "aoti_runtime", header_cpp_file
                )
            ) as f:
                self.header.splice(f.read())
    else:
        # JIT mode: the C++ source is embedded inside a Python module string.
        self.header.splice(
            """
            import torch
            from torch._inductor.codecache import CppWrapperCodeCache

            cpp_wrapper_src = (
            '''
            """
        )

    if config.abi_compatible:
        if config.c_shim_version == "1":
            self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
        else:
            self.header.splice(
                f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>"
            )
        self.header.splice(
            """
            #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
            #include <torch/csrc/inductor/aoti_runtime/thread_local.h>
            #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
            """
        )
        if V.graph.aot_mode:
            self.header.splice(
                """
                #include <torch/csrc/inductor/aoti_runtime/model.h>
                """
            )
    else:
        self.header.splice(
            """
            #include <ATen/ATen.h>
            #include <ATen/core/dispatch/Dispatcher.h>
            #include <ATen/native/BinaryOps.h>
            #include <torch/csrc/inductor/aoti_runtime/utils.h>
            #include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
            #include <torch/csrc/inductor/inductor_ops.h>
            #include <torch/types.h>
            #include <ATen/ops/bernoulli_native.h>

            #define reinterpret_tensor torch::inductor::_reinterpret_tensor
            #define alloc_from_pool torch::inductor::_alloc_from_pool
            """
        )

    self.header.splice("#include <c10/util/generic_math.h>")

    if not V.graph.aot_mode:
        self.header.splice(
            """
            #include <pybind11/pybind11.h>

            using namespace torch::aot_inductor;
            """
        )

    from .memory_planning import ALIGN_BYTES

    # Round up to the nearest multiple of ALIGN_BYTES
    # ALIGN_BYTES must be a power of 2
    self.header.splice(
        f"""
        [[maybe_unused]] static int64_t align(int64_t nbytes) {{
          return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES};
        }}
        """
    )
211
+
212
def mark_output_type(self):
    """Record, per graph output index, whether the output is a real tensor
    (True) or a scalar promoted to a tensor that must be unwrapped back to
    a Python scalar later (False)."""
    from ..ir import ShapeAsConstantBuffer

    self.output_is_tensor = {
        idx: not isinstance(out, ShapeAsConstantBuffer)
        for idx, out in enumerate(V.graph.graph_outputs)
    }
224
+
225
def write_prefix(self):
    # We do not write prefix for constant graph, it will be written by main module.
    if V.graph.is_const_graph:
        return

    # AOT code lives inside the torch::aot_inductor namespace.
    if V.graph.aot_mode:
        self.prefix.writeline("namespace torch {")
        self.prefix.writeline("namespace aot_inductor {")
233
+
234
def write_input_output_info(
    self,
    info_kind: str,
    idx: int,
    name: str,
):
    # Emit e.g. `inputs_info_[0].name = "arg0";` into the constructor body.
    self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""")
241
+
242
@staticmethod
def get_input_cpp_type(input):
    """Map a graph input to the C++ type used by the minimal-arrayref ABI:
    a plain scalar type for symbolic (sympy) inputs, an ArrayRefTensor of
    the element type otherwise."""
    assert config.use_minimal_arrayref_interface
    from .cpp import DTYPE_TO_CPP

    if isinstance(input, sympy.Expr):
        from ..graph import may_get_constant_buffer_dtype

        dtype = may_get_constant_buffer_dtype(input)
        assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
        return DTYPE_TO_CPP[dtype]
    return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
254
+
255
def write_wrapper_decl(self):
    """Emit the entry-point function declaration plus the prologue that
    unpacks input handles, constants, and (in AOT mode) kernel state."""
    inputs_len = len(V.graph.graph_inputs.keys())
    if V.graph.aot_mode:
        if config.use_minimal_arrayref_interface and not V.graph.is_const_graph:
            from .cpp import DTYPE_TO_CPP

            input_cpp_types = ", ".join(
                f"{CppWrapperCpu.get_input_cpp_type(x)}"
                for x in V.graph.graph_inputs.values()
            )

            output_arrayref_types = ", ".join(
                f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
                for x in V.graph.graph_outputs
            )

            self.prefix.splice(
                f"""
                using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
                using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
                """
            )

        if V.graph.const_module:
            self.header.splice(V.graph.const_module.wrapper_code.header)
            self.prefix.splice(V.graph.const_code)

        if V.graph.is_const_graph:
            self.prefix.splice(
                """
                void AOTInductorModel::_const_run_impl(
                    std::vector<AtenTensorHandle>& output_handles,
                    DeviceStreamType stream,
                    AOTIProxyExecutorHandle proxy_executor
                ) {
                """
            )
        else:
            if not config.aot_inductor.use_runtime_constant_folding:
                # If we do not split the constant graph, we'll just create
                # an empty implementation when wrapping the main module.
                self.prefix.splice(
                    """
                    void AOTInductorModel::_const_run_impl(
                        std::vector<AtenTensorHandle>& output_handles,
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {}

                    """
                )

            run_impl_proto = """
                void AOTInductorModel::run_impl(
                    AtenTensorHandle*
                        input_handles, // array of input AtenTensorHandle; handles
                                       // are stolen; the array itself is borrowed
                    AtenTensorHandle*
                        output_handles, // array for writing output AtenTensorHandle; handles
                                        // will be stolen by the caller; the array itself is
                                        // borrowed
                    DeviceStreamType stream,
                    AOTIProxyExecutorHandle proxy_executor
                ) {
                """
            if config.use_minimal_arrayref_interface:
                self.prefix.splice(
                    """
                    template <>
                    AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
                      AOTInductorModelInputs, AOTInductorModelOutputs>(
                        const AOTInductorModelInputs& inputs,
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {
                    """
                )
                self.suffix.splice(run_impl_proto)
                self.suffix.splice(
                    """
                    AOTInductorModelInputs inputs;
                    convert_handles_to_inputs(input_handles, inputs);
                    auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                        inputs, stream, proxy_executor);
                    // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
                    // interface to perform well for a DSO using the minimal arrayref interface, all we need
                    // to do is provide ThreadLocalCachedTensor for each one!
                    convert_outputs_to_handles(outputs, output_handles);
                    }
                    """
                )

                self.suffix.splice(
                    """
                    extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
                        AOTInductorModelHandle model_handle,
                        const AOTInductorModelInputs& inputs,
                        AOTInductorModelOutputs& outputs) {
                      auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
                      CONVERT_EXCEPTION_TO_ERROR_CODE({
                        outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                            inputs,
                            (torch::aot_inductor::DeviceStreamType)nullptr,
                            nullptr);
                      })
                    }
                    """
                )
            else:
                self.prefix.splice(run_impl_proto)
    else:
        # JIT (pybind) entry point.
        self.prefix.splice(
            """
            void inductor_entry_impl(
                AtenTensorHandle*
                    input_handles, // array of input AtenTensorHandle; handles
                                   // are stolen; the array itself is borrowed
                AtenTensorHandle*
                    output_handles // array for writing output AtenTensorHandle; handles
                                   // will be stolen by the caller; the array itself is
                                   // borrowed)
            ) {
            """
        )
    with self.prefix.indent():
        # assign inputs and outputs in both cases so the later codegen can be simplified
        if not config.use_minimal_arrayref_interface:
            if not V.graph.is_const_graph:
                if V.graph.aot_mode:
                    num_args = len(V.graph.graph_inputs)
                else:
                    # Weights are promoted in the JIT mode
                    num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
                    self.prefix.splice(
                        """
                        pybind11::gil_scoped_release release;
                        """
                    )

                if config.abi_compatible:
                    self.prefix.splice(
                        f"""
                        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
                        """
                    )
                else:
                    # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime.
                    self.prefix.splice(
                        f"""
                        auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args});
                        """
                    )

        if inputs_len != 0:
            for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                if config.use_minimal_arrayref_interface:
                    self.prefix.writeline(
                        f"auto {input_key} = std::get<{idx}>(inputs);"
                    )
                    continue
                # unwrap input tensor back to scalar
                if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
                    from ..graph import may_get_constant_buffer_dtype
                    from .cpp import DTYPE_TO_CPP

                    dtype = may_get_constant_buffer_dtype(
                        V.graph.graph_inputs[input_key]
                    )
                    assert (
                        dtype is not None
                    ), "Fails to get the dtype of the sympy.Expr"
                    cpp_dtype = DTYPE_TO_CPP[dtype]
                    if config.abi_compatible:
                        self.prefix.writeline(f"{cpp_dtype} {input_key};")
                        dtype_str = str(dtype).split(".")[-1]
                        self.prefix.writeline(
                            f"aoti_torch_item_{dtype_str}(inputs[{idx}], &{input_key});"
                        )
                    else:
                        self.prefix.writeline(
                            f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();"
                        )
                else:
                    self.prefix.writeline(
                        f"auto {input_key} = std::move(inputs[{idx}]);"
                    )

        assert all(
            isinstance(v, torch.Tensor) for v in V.graph.constants.values()
        ), "Expect all constants to be Tensor"
        for idx, constants_key in enumerate(V.graph.constants.keys()):
            if V.graph.aot_mode:
                # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
                # Don't call std::move here because it will cause constants_ to lose the ownership.
                if config.abi_compatible:
                    self.prefix.writeline(
                        f"""auto {constants_key} = constants_->at({idx});"""
                    )
                else:
                    self.prefix.writeline(
                        f"auto {constants_key} = *tensor_handle_to_tensor_pointer("
                        + f"""constants_->at({idx}));"""
                    )
            else:
                # Append constants as inputs to the graph
                constants_idx = inputs_len + idx
                self.prefix.writeline(
                    f"auto {constants_key} = inputs[{constants_idx}];"
                )

        self.codegen_inputs(self.prefix, V.graph.graph_inputs)

        if V.graph.aot_mode:
            if not V.graph.is_const_graph:
                if config.use_minimal_arrayref_interface:
                    # TODO: input shape checking for regular tensor interface as well?
                    self.codegen_input_numel_asserts()
                else:
                    self.prefix.writeline("inputs.clear();")
            self.prefix.writeline(
                "auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
            )
477
+
478
def codegen_input_numel_asserts(self):
    """Emit an assert_numel(...) check for every non-symbolic, non-empty
    tensor input."""
    for name, buf in V.graph.graph_inputs.items():
        # Symbolic (scalar) inputs have no numel to check.
        if isinstance(buf, sympy.Expr):
            continue

        # comparing strides for 0 size tensor is tricky. Ignore them for now.
        if sympy_product(buf.get_size()) == 0:
            continue
        self.prefix.writeline(f"assert_numel({name}, {buf.get_numel()});")
488
+
489
def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
    # In ABI-compatible mode, sizes are fetched through the C shim rather
    # than a direct method call on the tensor.
    if not config.abi_compatible:
        super().codegen_input_size_var_decl(code, name)
        return
    code.writeline(f"int64_t* {name}_size;")
    code.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));"
    )
497
+
498
def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
    # Mirrors codegen_input_size_var_decl, but for strides.
    if not config.abi_compatible:
        super().codegen_input_stride_var_decl(code, name)
        return
    code.writeline(f"int64_t* {name}_stride;")
    code.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));"
    )
506
+
507
def codegen_model_kernels(self):
    """Declare an anonymous-namespace AOTInductorModelKernels class with
    one CUfunction slot per kernel used by this graph (including any
    user-defined kernels and the const module's kernels)."""
    self.prefix.writeline("namespace {")
    self.prefix.writeline(
        "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {"
    )
    self.prefix.writeline(" public:")
    # Collect every kernel name that needs a declaration, de-duplicated.
    declare_kernel = set(self.src_to_kernel.values())
    declare_kernel.update(
        entry[0] for entry in self.user_defined_kernel_cache.values()
    )
    if V.graph.const_module:
        declare_kernel.update(
            V.graph.const_module.wrapper_code.src_to_kernel.values()
        )
    for kernel in declare_kernel:
        self.prefix.writeline(f" CUfunction {kernel}{{nullptr}};")
    self.prefix.writeline("};")
    self.prefix.writeline("} // namespace")
525
+
526
def codegen_model_constructor(self):
    """
    // Generated code example
    AOTInductorModel::AOTInductorModel()
        : AOTInductorModelBase(4, 1) {
    inputs_info_[0].name = "input0";
    inputs_info_[0].dtype = "torch.float16";
    ...
    constants_info_[0].name = "L__self___weight";
    constants_info_[0].dtype = at::kFloat;
    constants_info_[0].offset = 0;
    constants_info_[0].data_size = 8192;
    constants_info_[0].shape = {64, 32};
    constants_info_[0].stride = {32, 1};
    ...
    outputs_info_[0].name = "output0";
    outputs_info_[0].dtype = "torch.float16";
    }
    """

    num_inputs = len(V.graph.graph_inputs)
    num_outputs = len(V.graph.graph_outputs)
    num_constants = len(V.graph.constants)
    self.prefix.splice(
        f"""
        AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
                                           std::shared_ptr<std::vector<ConstantHandle>> constants_array,
                                           const std::string& device_str,
                                           std::optional<std::string> cubin_dir)
            : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
        """
    )

    with self.prefix.indent():
        # Inputs must be concrete tensors; record their names.
        for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
            assert not isinstance(
                inp, sympy.Expr
            ), f"input {name=} cannot be symbolic"
            self.write_input_output_info("inputs_info_", idx, name)

        # Record metadata (dtype/offset/size/shape/stride/origin) for each
        # constant so the runtime can validate and map the weights.
        for idx, (name, tensor) in enumerate(V.graph.constants.items()):
            assert isinstance(tensor, torch.Tensor)
            self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
            self.prefix.writeline(
                f"constants_info_[{idx}].dtype = static_cast<int32_t>({self.codegen_dtype(tensor.dtype)});"
            )
            self.prefix.writeline(
                f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
            )
            self.prefix.writeline(
                f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
            )
            from_folded = "true" if name in V.graph.folded_constants else "false"
            self.prefix.writeline(
                f"constants_info_[{idx}].from_folded = {from_folded};"
            )

            size_str = ", ".join([str(s) for s in tensor.size()])
            self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")

            stride_str = ", ".join([str(s) for s in tensor.stride()])
            self.prefix.writeline(
                f"constants_info_[{idx}].stride = {{{stride_str}}};"
            )
            if name in V.graph.dynamo_flat_name_to_original_fqn:
                original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
                    name, name
                )
            elif name in V.graph.allocated_constant_name:
                original_fqn = V.graph.allocated_constant_name[name]
            else:
                raise AssertionError("original_fqn must be set for constant")
            self.prefix.writeline(
                f"""constants_info_[{idx}].original_fqn = "{original_fqn}";"""
            )
        self.prefix.writeline("update_constants_map(std::move(constants_map));")
        self.prefix.writeline("update_constants_array(std::move(constants_array));")

        def escape_string(x):
            # Escape the serialized pytree specs so they survive being
            # embedded in a C++ string literal.
            return (
                x.replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\n", "\\n")
                .replace("\t", "\\t")
            )

        self.prefix.writeline(
            f'in_spec_ = "{escape_string(config.aot_inductor.serialized_in_spec)}";'
        )
        self.prefix.writeline(
            f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";'
        )

        for idx, output in enumerate(V.graph.graph_outputs):
            # BUGFIX: the original message used f"output {name=}", which
            # reported the stale `name` from the constants loop above (the
            # output's name is only assigned after this assert). Report the
            # output index instead.
            assert not isinstance(
                output, sympy.Expr
            ), f"output at {idx=} cannot be symbolic"
            name = f"output{idx}"
            self.write_input_output_info("outputs_info_", idx, name)

        self.prefix.writeline(
            "this->kernels_ = std::make_unique<AOTInductorModelKernels>();"
        )

    self.prefix.writeline("}")
631
+
632
def codegen_const_run_driver(self):
    """
    // Generated code example
    std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
        DeviceStreamType stream,
        AOTIProxyExecutorHandle proxy_executor,
        bool initialization
    ) {
        std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
        std::vector<AtenTensorHandle> output_handles;
        // build up output_handles over here.
        _const_run_impl(output_handles, stream, proxy_executor);
        // build up folded_constants_map
        return folded_constants_map;
    }
    """

    self.prefix.splice(
        """
        std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
            DeviceStreamType stream,
            AOTIProxyExecutorHandle proxy_executor,
            bool initialization
        ) {
        """
    )
    if not config.aot_inductor.use_runtime_constant_folding:
        # Constant folding is disabled: emit a warning stub and bail out.
        self.prefix.splice(
            """
            if (!initialization) {
                std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: "
                          << "aot_inductor.use_runtime_constant_folding=False\\n";
            }
            return {};
        }
        """
        )
        return

    with self.prefix.indent():
        # This is a mapping to the index of constant folding graph's output
        const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len(
            V.graph.const_output_index
        )
        for idx, (name, _) in enumerate(V.graph.constants.items()):
            if name in V.graph.const_output_index:
                const_index_mapping[V.graph.const_output_index[name]] = (idx, name)  # type: ignore[call-overload]
        assert (
            None not in const_index_mapping
        ), "Not all constant gets mapped for constant folding graph."

        self.prefix.writeline(
            f"""
            std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
            folded_constants_map.reserve({len(const_index_mapping)});
            std::vector<AtenTensorHandle> output_handles({len(const_index_mapping)});
            """
        )

        self.prefix.splice(
            """
            // The below assignment of output_handles to constants is not used directly.
            // It's only used to memo the correspondence of handle and constants.
            """
        )

        for output_idx, (const_idx, _) in enumerate(const_index_mapping):  # type: ignore[misc]
            self.prefix.writeline(
                f"output_handles[{output_idx}] = constants_->at({const_idx});"
            )

        self.prefix.writeline(
            "_const_run_impl(output_handles, stream, proxy_executor);"
        )

        for output_idx, (_, const_name) in enumerate(const_index_mapping):  # type: ignore[misc]
            self.prefix.writeline(
                f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];'
            )
        self.prefix.writeline("return folded_constants_map;")

    self.prefix.writeline("}")
714
+
715
def generate(self, is_inference):
    # In AOT mode (main graph only), emit the model scaffolding — kernel
    # declarations, constructor, and const-folding driver — before the
    # wrapper body, then delegate to the base implementation.
    if V.graph.aot_mode and not V.graph.is_const_graph:
        self.codegen_model_kernels()
        self.codegen_model_constructor()
        self.codegen_const_run_driver()
    self.write_wrapper_decl()
    return super().generate(is_inference)
722
+
723
def finalize_prefix(self):
    """Prepend cached dtype/device macro instantiations (ABI-compatible
    mode only) ahead of the existing prefix."""
    buf = IndentedBuffer()
    if config.abi_compatible:
        for dtype in self.used_cached_dtypes:
            buf.writeline(f"CACHE_TORCH_DTYPE({dtype});")
        for device in self.used_cached_devices:
            buf.writeline(f"CACHE_TORCH_DEVICE({device});")
    buf.splice(self.prefix)
    self.prefix = buf
732
+
733
def define_kernel(
    self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False
):
    # CPU kernels are pasted verbatim into the generated header.
    self.header.splice(f"\n{kernel}\n")
737
+
738
def codegen_scalar_to_tensor(self, output: str):
    """Wrap a scalar expression into a fresh RAII tensor handle and return
    the generated variable name."""
    name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}"
    self.wrapper_call.writeline(
        f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});"
    )
    return name
744
+
745
@cache_on_self
def get_output_refs(self):
    # In non-ABI mode a scalar (ShapeAsConstantBuffer) output must be
    # wrapped into a tensor because the entry point returns tensors only.
    refs = []
    for out in V.graph.graph_outputs:
        ref = out.codegen_reference(self.wrapper_call)
        if isinstance(out, ir.ShapeAsConstantBuffer) and not config.abi_compatible:
            ref = f"torch::tensor({ref})"
        refs.append(ref)
    return refs
753
+
754
def generate_return(self, output_refs):
    """Emit the epilogue that hands graph outputs back to the caller,
    handling the ABI-compatible, non-ABI, and minimal-arrayref variants."""
    cst_names = V.graph.constants.keys()
    arr_iface = (
        not V.graph.is_const_graph and config.use_minimal_arrayref_interface
    )  # For brevity.

    def use_thread_local_cached_output_tensor(idx, output):
        # Fallback path for outputs that are not tensor handles: copy into
        # a thread_local cache and expose it via the output slot.
        cached_output_name = f"cached_output_{next(self.cached_output_id)}"
        cache_type = "Array" if arr_iface else "Tensor"
        self.wrapper_call.writeline(
            f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
            f"{cached_output_name}({output});"
        )
        if arr_iface:
            self.wrapper_call.writeline(
                f"{cached_output_name}.copy_data_from({output});"
            )
            output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
            element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
            self.wrapper_call.writeline(
                f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
            )
        else:
            self.wrapper_call.writeline(
                f"{cached_output_name}.copy_data_from({output});"
            )
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
            )
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
                f"output_handles[{idx}]));"
            )

    if arr_iface:
        self.wrapper_call.writeline(
            "AOTInductorModelOutputs output_arrayref_tensors;"
        )
    for idx, output in enumerate(output_refs):
        if config.abi_compatible:
            output_buffer = V.graph.graph_outputs[idx]
            if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
                # Need to wrap scalar into tensor as the main function returns a vector of tensors
                output_tensor = self.codegen_scalar_to_tensor(output)
                self.wrapper_call.writeline(
                    f"output_handles[{idx}] = {output_tensor}.release();"
                )
                continue

            output_is_tensor_handle_expr = (
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "RAIIAtenTensorHandle> || "
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "AtenTensorHandle> || "
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "ConstantHandle>"
            )
            self.wrapper_call.writeline(
                f"if constexpr ({output_is_tensor_handle_expr}) {{"
            )
            with self.wrapper_call.indent():
                if arr_iface:
                    cached_output_name = (
                        f"cached_output_{next(self.cached_output_id)}"
                    )
                    # NOTE(review): output_value_type appears unused below —
                    # kept for fidelity; candidate for removal upstream.
                    output_value_type = f"std::decay_t<decltype(std::get<{idx}>(output_arrayref_tensors).data()[0])>"
                    self.wrapper_call.writeline(
                        f"thread_local RAIIAtenTensorHandle {cached_output_name};"
                    )
                    if output in cst_names:
                        # NOTE(return_constant): In some rare cases where we return
                        # a constant, we have to return a copy of this constant,
                        # because (1) constants are not owned by the Model instance
                        # (2) constants remain the same cross inference runs,
                        # assuming they are not updated at runtime Basically, we
                        # cannot release or transfer the ownership of any original
                        # constant to the user.
                        self.wrapper_call.writeline(
                            f"AtenTensorHandle {cached_output_name}_tmp;"
                        )
                        self.wrapper_call.writeline(
                            f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
                        )
                        self.wrapper_call.writeline(
                            f"{cached_output_name} = {cached_output_name}_tmp;"
                        )
                    else:
                        self.wrapper_call.writeline(
                            f"{cached_output_name} = {output}.release();"
                        )
                    self.wrapper_call.writeline(
                        f"convert_handle_to_arrayref_tensor({cached_output_name}, "
                        f"std::get<{idx}>(output_arrayref_tensors));"
                    )
                else:
                    if output in cst_names:
                        # See NOTE(return_constant) above.
                        self.wrapper_call.writeline(
                            f"aoti_torch_clone({output}, &output_handles[{idx}]);"
                        )
                    else:
                        self.wrapper_call.writeline(
                            f"output_handles[{idx}] = {output}.release();"
                        )
            self.wrapper_call.writeline("} else {")
            with self.wrapper_call.indent():
                use_thread_local_cached_output_tensor(idx, output)
            self.wrapper_call.writeline("}")

        else:
            assert (
                not arr_iface
            ), "minimal ArrayRef interface is only supported in ABI-compatible mode"
            if output in cst_names:
                output_expr = f"{output}.clone()"
                # See NOTE(return_constant) above.
            else:
                output_expr = output
            self.wrapper_call.writeline(
                f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
                + f"new at::Tensor({output_expr}));"
            )
    if arr_iface:
        self.wrapper_call.writeline("return output_arrayref_tensors;")
878
+
879
def generate_before_suffix(self, result):
    # Close the entry function that write_wrapper_decl opened.
    if V.graph.is_const_graph:
        return
    closing = (
        "} // AOTInductorModel::run_impl"
        if V.graph.aot_mode
        else "} // inductor_entry_impl"
    )
    result.writeline(closing)
885
+
886
def generate_end(self, result):
    """Close the generated source: namespaces in AOT mode, or the Python
    loader/wrapper shim in JIT mode."""
    if V.graph.aot_mode:
        if V.graph.is_const_graph:
            result.writeline("} // AOTInductorModel::_const_run_impl")
        else:
            result.writeline("} // namespace aot_inductor")
            result.writeline("} // namespace torch")
        return

    # JIT mode: terminate the embedded C++ string and compile/load it.
    result.writeline("'''\n)")
    result.splice(
        f"""
        inductor_entry = CppWrapperCodeCache.load_pybinding(
            ["std::vector<at::Tensor>"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)})
        """
    )

    # unwrap output tensor back to python scalar
    if all(x for x in self.output_is_tensor.values()):
        # If no ShapeAsConstantBuffer in the output, directly return the output as tensors
        return_str = "return f(args_tensor)"
    else:
        outputs = [
            f"outputs[{i}]" if self.output_is_tensor[i] else f"outputs[{i}].item()"
            for i in range(len(V.graph.graph_outputs))
        ]
        outputs_str = f"[{', '.join(outputs)}]"
        return_str = f"""
                outputs = f(args_tensor)
                return {outputs_str}
                """

    args_str = "args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
    if V.graph.constants:
        # Append constants to the input args for cpp wrapper.
        # Python wrapper directly gets the value inside the wrapper call
        # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__).
        # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly.
        assert all(
            isinstance(v, torch.Tensor) for v in V.graph.constants.values()
        ), "Expect all constants to be Tensor"
        constants_str = f"[{', '.join(V.graph.constants.keys())}]"
        args_str += f"""
            constants_tensor = {constants_str}
            args_tensor.extend(constants_tensor)
            """

    # Wrap the func to support setting result._boxed_call = True
    result.splice(
        f"""
        def _wrap_func(f):
            def g(args):
                {args_str}
                {return_str}
            return g
        call = _wrap_func(inductor_entry)
        """
    )
944
+
945
+ def generate_c_shim_extern_kernel_call(self, kernel, args):
946
+ # In the abi_compatible mode, we call fallback aten ops through a C shim layer
947
+ self.allow_stack_allocation = False
948
+ kernel_tokens = kernel.split("::")
949
+ kernel_suffix = kernel_tokens[-1]
950
+ if kernel_suffix == "call":
951
+ kernel_suffix = kernel_tokens[-2]
952
+ if config.c_shim_version == "1":
953
+ shim_fn = f"aoti_torch_{kernel_suffix}"
954
+ else:
955
+ shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}"
956
+
957
+ # HACK: val_to_arg_str jams multiple arguments together using a comma. If that
958
+ # ever breaks, it needs to be reworked to be able to return multiple arguments,
959
+ # and the split-on-comma code here needs to be removed.
960
+ wrapped_args = []
961
+ for x in args:
962
+ pieces = x.split(", ")
963
+ for piece in pieces:
964
+ # We only really *need* convert_arrayref_tensor_to_tensor for
965
+ # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
966
+ # which convert_arrayref_tensor_to_tensor would blindly coerce to int,
967
+ # so just avoid wrapping integers.
968
+ if not piece.isdigit():
969
+ piece = f"convert_arrayref_tensor_to_tensor({piece})"
970
+ wrapped_args.append(piece)
971
+ self.writeline(
972
+ f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));"
973
+ )
974
+
975
+ def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
976
+ # registered output buffer name
977
+ name = extern_kernel.name
978
+ output_handle_name = f"{name}_handle"
979
+ self.writeline(f"AtenTensorHandle {output_handle_name};")
980
+ output_arg = f"&{output_handle_name}"
981
+ self.generate_c_shim_extern_kernel_call(
982
+ extern_kernel.get_kernel_name(), args + [output_arg]
983
+ )
984
+ self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
985
+
986
+ def generate_extern_kernel_alloc(self, extern_kernel, args):
987
+ if config.abi_compatible:
988
+ self.generate_c_shim_extern_kernel_alloc(extern_kernel, args)
989
+ else:
990
+ super().generate_extern_kernel_alloc(extern_kernel, args)
991
+
992
+ def generate_c_shim_fallback_kernel(self, fallback_kernel, args):
993
+ output_args = []
994
+ output_raii_handles = []
995
+ output_name_base = fallback_kernel.get_name()
996
+ for idx, output in enumerate(fallback_kernel.outputs):
997
+ if isinstance(output, ir.MultiOutput):
998
+ name = f"{output.get_name()}"
999
+ output_handle_name = f"{name}_handle"
1000
+ if output.indices:
1001
+ assert (
1002
+ output.indices[0][1] == idx
1003
+ ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
1004
+ self.writeline(f"AtenTensorHandle {output_handle_name};")
1005
+ output_args.append(f"&{output_handle_name}")
1006
+ output_raii_handles.append(
1007
+ f"RAIIAtenTensorHandle {name}({output_handle_name});"
1008
+ )
1009
+ elif isinstance(output, int):
1010
+ output_name = f"{output_name_base}_{idx}"
1011
+ self.writeline(f"int64_t {output_name} = {output};")
1012
+ output_args.append(f"&{output_name}")
1013
+ elif output is None:
1014
+ output_args.append("nullptr")
1015
+ else:
1016
+ raise NotImplementedError("unsupported type of {output=}")
1017
+ args = args + output_args
1018
+ assert (
1019
+ fallback_kernel.abi_compatible_kernel is not None
1020
+ ), f"abi_compatible_kernel is None for {fallback_kernel.python_kernel_name=}"
1021
+ self.generate_c_shim_extern_kernel_call(
1022
+ fallback_kernel.abi_compatible_kernel, args
1023
+ )
1024
+ for raii_handle in output_raii_handles:
1025
+ self.writeline(raii_handle)
1026
+
1027
+ def generate_fallback_kernel(self, fallback_kernel, args):
1028
+ if config.abi_compatible:
1029
+ self.generate_c_shim_fallback_kernel(fallback_kernel, args)
1030
+ else:
1031
+ super().generate_fallback_kernel(fallback_kernel, args)
1032
+
1033
+ def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
1034
+ if output_view:
1035
+ output_as_strided = f"{output_view.codegen_reference()}"
1036
+ output_name = f"{output_view.get_name()}_as_strided"
1037
+ self.writeline(f"auto {output_name} = {output_as_strided};")
1038
+
1039
+ args.insert(0, output_name)
1040
+ else:
1041
+ args.insert(0, f"{codegen_reference}")
1042
+
1043
+ if config.abi_compatible:
1044
+ self.generate_c_shim_extern_kernel_call(kernel, args)
1045
+ else:
1046
+ self.writeline(self.wrap_kernel_call(kernel, args))
1047
+
1048
+ def generate_user_defined_triton_kernel(
1049
+ self, kernel_name, grid, configs, args, triton_meta
1050
+ ):
1051
+ assert len(grid) != 0
1052
+ if len(grid) == 1:
1053
+ grid_decision = grid[0]
1054
+ else:
1055
+ meta = CudaKernelParamCache.get(kernel_name)
1056
+ assert meta is not None
1057
+ grid_decision = None
1058
+ for i, c in enumerate(configs):
1059
+ if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()):
1060
+ grid_decision = grid[i]
1061
+ break
1062
+ assert grid_decision is not None
1063
+
1064
+ self.generate_kernel_call(
1065
+ kernel_name,
1066
+ args,
1067
+ grid=grid_decision,
1068
+ device_index=V.graph.scheduler.current_device.index,
1069
+ cuda=True,
1070
+ triton=True,
1071
+ triton_meta=triton_meta,
1072
+ )
1073
+
1074
+ def generate_scatter_fallback(
1075
+ self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
1076
+ ):
1077
+ # TODO: support other overload for cpp wrapper and remove the below assertions
1078
+ if config.abi_compatible:
1079
+ # call the ABI shim function instead of the ATen one
1080
+ kernel = kernel.replace("at::", "aoti_torch_")
1081
+ line = f"{kernel}({output}, {','.join(map(str, inputs))}"
1082
+ if python_kernel_name == "aten.scatter_":
1083
+ if src_is_tensor:
1084
+ if reduce:
1085
+ line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
1086
+ else:
1087
+ assert (
1088
+ reduce is None
1089
+ ), "Expect reduce to be None for aten.scatter_ with scalar src"
1090
+ else:
1091
+ line += f", {','.join(kwargs)}"
1092
+ line += f"){self.ending}"
1093
+ self.writeline(line)
1094
+
1095
+ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
1096
+ if V.graph.aot_mode and V.graph.cpp_wrapper and config.abi_compatible:
1097
+ # See the comment in codegen_reinterpret_view about why having something like
1098
+ # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding
1099
+ # tensor prematurely deallocated, thus this std::vector().data() trick here.
1100
+ indices_str = (
1101
+ f"std::vector<AtenTensorHandle>{{{', '.join(indices)}}}.data()"
1102
+ )
1103
+ args = [x, indices_str, str(len(indices)), values, accumulate]
1104
+ else:
1105
+ indices_str = (
1106
+ f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
1107
+ )
1108
+ args = [x, indices_str, values, accumulate]
1109
+
1110
+ args.insert(0, x) # set x as the output tensor, this fallback mutates x.
1111
+ self.writeline(self.wrap_kernel_call(kernel, args))
1112
+
1113
+ def add_benchmark_harness(self, output):
1114
+ if V.graph.aot_mode:
1115
+ return
1116
+ super().add_benchmark_harness(output)
1117
+
1118
+ def codegen_sizevar(self, x: Expr) -> str:
1119
+ return self.expr_printer(V.graph.sizevars.simplify(x))
1120
+
1121
+ def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
1122
+ if config.abi_compatible:
1123
+ # in the abi_compatible mode, outputs are returned via arguments
1124
+ return name
1125
+ else:
1126
+ return f"std::get<{index}>({basename})"
1127
+
1128
+ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
1129
+ parts = list(map(self.codegen_sizevar, shape))
1130
+ if len(parts) == 0:
1131
+ return "{}"
1132
+ if len(parts) == 1:
1133
+ return f"{{{parts[0]}, }}"
1134
+ return f"{{{', '.join(parts)}}}"
1135
+
1136
+ def codegen_dynamic_scalar(self, node):
1137
+ from .cpp import DTYPE_TO_ATEN, DTYPE_TO_CPP
1138
+
1139
+ (data,) = (t.codegen_reference() for t in node.inputs)
1140
+ if config.abi_compatible:
1141
+ dtype = node.inputs[0].get_dtype()
1142
+ dtype_str = str(dtype).split(".")[-1]
1143
+ self.writeline(f"{DTYPE_TO_CPP[dtype]} {node.sym};")
1144
+ self.writeline(f"aoti_torch_item_{dtype_str}({data}, &{node.sym});")
1145
+ # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
1146
+ self.unbacked_symbol_decls.add(str(node.sym))
1147
+ else:
1148
+ if node.is_bool:
1149
+ self.writeline(f"bool {node.sym} = {data}.item() ? 1 : 0;")
1150
+ else:
1151
+ convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace(
1152
+ "at::k", "to"
1153
+ )
1154
+ self.writeline(f"auto {node.sym} = {data}.item().{convert_type}();")
1155
+
1156
+ def can_stack_allocate_buffer(self, buffer):
1157
+ return (
1158
+ self.allow_stack_allocation
1159
+ and buffer.get_device().type == "cpu"
1160
+ and self.can_prove_buffer_has_static_shape(buffer)
1161
+ and ir.is_contiguous_strides_for_shape(
1162
+ buffer.get_stride(), buffer.get_size()
1163
+ )
1164
+ )
1165
+
1166
+ def make_buffer_free(self, buffer):
1167
+ return (
1168
+ ""
1169
+ if isinstance(buffer.get_layout(), ir.MultiOutputLayout)
1170
+ or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
1171
+ or (
1172
+ config.use_minimal_arrayref_interface
1173
+ and V.graph.aot_mode
1174
+ and buffer.get_name() in V.graph.graph_inputs
1175
+ )
1176
+ else f"{buffer.get_name()}.reset();"
1177
+ )
1178
+
1179
+ def make_free_by_names(self, names_to_del: List[str]):
1180
+ return " ".join(f"{name}.reset();" for name in names_to_del)
1181
+
1182
+ def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
1183
+ if config.abi_compatible:
1184
+ return f"auto {new_name} = std::move({old_name}); // reuse"
1185
+ else:
1186
+ return super().codegen_exact_buffer_reuse(old_name, new_name, del_line)
1187
+
1188
+ def generate_profiler_mark_wrapper_call(self, stack):
1189
+ self.wrapper_call.writeline(
1190
+ 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
1191
+ )
1192
+
1193
+ def write_triton_header_once(self):
1194
+ pass
1195
+
1196
+ def generate_start_graph(self):
1197
+ pass
1198
+
1199
+ def generate_end_graph(self):
1200
+ pass
1201
+
1202
+ def generate_inf_and_nan_checker(self, nodes):
1203
+ for buf in nodes.get_names():
1204
+ # TODO: Add buf name directly into check_inf_and_nan.
1205
+ self.writeline(
1206
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_check_inf_and_nan({buf}));"
1207
+ )
1208
+
1209
+ def codegen_device(self, device):
1210
+ if config.abi_compatible:
1211
+ self.used_cached_devices.add(device.type)
1212
+ return f"cached_torch_device_type_{device.type},{device.index if device.index else 0}"
1213
+ else:
1214
+ from .cpp import DEVICE_TO_ATEN
1215
+
1216
+ return (
1217
+ f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})"
1218
+ if device.index is not None
1219
+ else f"{DEVICE_TO_ATEN[device.type]}"
1220
+ )
1221
+
1222
+ def codegen_dtype(self, dtype):
1223
+ if config.abi_compatible:
1224
+ dtype_str = str(dtype).split(".")[-1]
1225
+ self.used_cached_dtypes.add(dtype_str)
1226
+ return f"cached_torch_dtype_{dtype_str}"
1227
+ else:
1228
+ from .cpp import DTYPE_TO_ATEN
1229
+
1230
+ return DTYPE_TO_ATEN[dtype]
1231
+
1232
+ @functools.lru_cache(None)
1233
+ def codegen_int_array_var(
1234
+ self,
1235
+ int_array: str,
1236
+ writer=None,
1237
+ known_statically=False,
1238
+ graph=None, # for per-graph caching
1239
+ ):
1240
+ # Because the memory planning is done in two passes (see the implementation
1241
+ # of self.generate), the writeline behavior is different in the two passes.
1242
+ # As a result, the emitted int array declarations may appear in a later
1243
+ # position of the generated code, so the second pass codegen should not
1244
+ # reuse int array declarations generated in the first pass
1245
+ if writer is None:
1246
+ # The first pass codegen uses `self` as the writer
1247
+ writer = self
1248
+
1249
+ var = f"int_array_{next(self.int_array_id)}"
1250
+ if var not in self.declared_int_array_vars:
1251
+ self.declared_int_array_vars.add(var)
1252
+ if known_statically:
1253
+ writer.writeline(f"static constexpr int64_t {var}[] = {int_array};")
1254
+ else:
1255
+ writer.writeline(f"int64_t {var}[] = {int_array};")
1256
+ return var
1257
+
1258
+ def make_buffer_allocation(self, buffer):
1259
+ return self.make_allocation(
1260
+ buffer.get_name(),
1261
+ buffer.get_device(),
1262
+ buffer.get_dtype(),
1263
+ buffer.get_size(),
1264
+ buffer.get_stride(),
1265
+ buffer if self.can_stack_allocate_buffer(buffer) else None,
1266
+ )
1267
+
1268
+ def make_allocation(
1269
+ self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None
1270
+ ):
1271
+ orig_stride = stride
1272
+ device_str = self.codegen_device(device)
1273
+ dtype_code = self.codegen_dtype(dtype)
1274
+ size = self.codegen_shape_tuple(shape)
1275
+ stride = self.codegen_shape_tuple(orig_stride)
1276
+ if config.abi_compatible:
1277
+ size_array_var = self.codegen_int_array_var(
1278
+ size,
1279
+ self.wrapper_call,
1280
+ known_statically=self.is_statically_known_list_of_ints(shape),
1281
+ graph=self.get_codegened_graph(),
1282
+ )
1283
+ stride_array_var = self.codegen_int_array_var(
1284
+ stride,
1285
+ self.wrapper_call,
1286
+ known_statically=self.is_statically_known_list_of_ints(orig_stride),
1287
+ graph=self.get_codegened_graph(),
1288
+ )
1289
+ device_type, device_id = device_str.split(",")
1290
+ device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
1291
+ if buffer_if_can_stack_allocate is not None:
1292
+ from .cpp import DTYPE_TO_CPP
1293
+
1294
+ self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
1295
+ cpp_type = DTYPE_TO_CPP[dtype]
1296
+ numel = buffer_if_can_stack_allocate.get_numel()
1297
+ # Note: we don't zero storage because empty_strided doesn't zero either.
1298
+ self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
1299
+ args = [
1300
+ f"{name}_storage",
1301
+ size_array_var,
1302
+ stride_array_var,
1303
+ device_type,
1304
+ device_idx,
1305
+ ]
1306
+ return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"
1307
+
1308
+ args = [
1309
+ str(len(shape)),
1310
+ size_array_var,
1311
+ stride_array_var,
1312
+ dtype_code,
1313
+ device_type,
1314
+ device_idx,
1315
+ f"&{name}_handle",
1316
+ ]
1317
+
1318
+ self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
1319
+ self.wrapper_call.writeline(
1320
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));"
1321
+ )
1322
+
1323
+ return f"RAIIAtenTensorHandle {name}({name}_handle);"
1324
+
1325
+ if V.graph.aot_mode and device_str.startswith("c10::Device("):
1326
+ tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)"
1327
+ else:
1328
+ tensor_device = device_str
1329
+
1330
+ if device.type == "cpu":
1331
+ return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});"
1332
+ if device.type == "cuda":
1333
+ return (
1334
+ f"at::Tensor {name} = at::detail::empty_strided_cuda("
1335
+ f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);"
1336
+ )
1337
+ return (
1338
+ f"{self.declare}{name} = {self.namespace}empty_strided("
1339
+ f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}"
1340
+ )
1341
+
1342
+ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
1343
+ if config.abi_compatible:
1344
+ size = self.codegen_shape_tuple(shape)
1345
+ stride = self.codegen_shape_tuple(stride)
1346
+ tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
1347
+ args = [
1348
+ name,
1349
+ pexpr(offset), # bytes not numel
1350
+ self.codegen_dtype(dtype),
1351
+ str(len(shape)),
1352
+ self.codegen_int_array_var(
1353
+ size, self.wrapper_call, graph=self.get_codegened_graph()
1354
+ ),
1355
+ self.codegen_int_array_var(
1356
+ stride, self.wrapper_call, graph=self.get_codegened_graph()
1357
+ ),
1358
+ f"&{tmp_name}",
1359
+ ]
1360
+ self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
1361
+ self.wrapper_call.writeline(
1362
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));"
1363
+ )
1364
+ return f"RAIIAtenTensorHandle({tmp_name})"
1365
+
1366
+ return "alloc_from_pool({})".format(
1367
+ ", ".join(
1368
+ [
1369
+ name,
1370
+ pexpr(offset), # bytes not numel
1371
+ self.codegen_dtype(dtype),
1372
+ self.codegen_shape_tuple(shape),
1373
+ self.codegen_shape_tuple(stride),
1374
+ ]
1375
+ )
1376
+ )
1377
+
1378
+ def codegen_reinterpret_view(
1379
+ self, data, size_list, stride_list, offset, writer
1380
+ ) -> str:
1381
+ dim = str(len(size_list))
1382
+ size = self.codegen_shape_tuple(size_list)
1383
+ stride = self.codegen_shape_tuple(stride_list)
1384
+ offset = self.codegen_sizevar(offset)
1385
+
1386
+ if config.abi_compatible:
1387
+ tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
1388
+ # Because the memory planning is done in two passes (see the implementation
1389
+ # of self.generate), the writeline behavior is different in the two passes.
1390
+ if writer is None:
1391
+ writer = self
1392
+
1393
+ args = [
1394
+ f"{data.get_name()}",
1395
+ dim,
1396
+ self.codegen_int_array_var(
1397
+ size,
1398
+ writer,
1399
+ known_statically=self.is_statically_known_list_of_ints(size_list),
1400
+ graph=self.get_codegened_graph(),
1401
+ ),
1402
+ self.codegen_int_array_var(
1403
+ stride,
1404
+ writer,
1405
+ known_statically=self.is_statically_known_list_of_ints(stride_list),
1406
+ graph=self.get_codegened_graph(),
1407
+ ),
1408
+ offset,
1409
+ ]
1410
+
1411
+ def gen_reinterpret_call(writer, args):
1412
+ writer.writeline(
1413
+ f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});"
1414
+ )
1415
+
1416
+ if (
1417
+ self.can_stack_allocate_buffer(data)
1418
+ and self.is_statically_known_list_of_ints(size_list)
1419
+ and self.is_statically_known_list_of_ints(stride_list)
1420
+ and ir.is_contiguous_strides_for_shape(stride_list, size_list)
1421
+ ):
1422
+ gen_reinterpret_call(writer, args)
1423
+ return tmp_name
1424
+
1425
+ gen_reinterpret_call(writer, args)
1426
+
1427
+ # NB, the return handle here represents a temporary tensor, which will be automatically
1428
+ # released.
1429
+ # Here's a sample usage in the cpp wrapper code:
1430
+ # ```
1431
+ # aoti_torch_addmm_out(
1432
+ # buf1,
1433
+ # arg1_1,
1434
+ # RAIIAtenTensorHandle(tmp_tensor_handle_0),
1435
+ # buf0,
1436
+ # 1L,
1437
+ # 1L));
1438
+ # ```
1439
+ # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
1440
+ # This could be problematic when it's used in a different pattern, for example:
1441
+ # ````
1442
+ # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
1443
+ # aoti_torch_proxy_executor_call_function(..., tensor_args);
1444
+ # ````
1445
+ # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
1446
+ # kernel call.
1447
+ #
1448
+ # This is solved by updating the proxy_executor invocation to
1449
+ # ```
1450
+ # aoti_torch_proxy_executor_call_function(...,
1451
+ # std::vector<AtenTensorHandle>{
1452
+ # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
1453
+ # }.data()
1454
+ # );
1455
+ # ```
1456
+ return f"wrap_with_raii_handle_if_needed({tmp_name})"
1457
+ else:
1458
+ args = [data.get_name(), size, stride, offset]
1459
+ return f"reinterpret_tensor({', '.join(args)})"
1460
+
1461
+ def codegen_device_copy(self, src, dst):
1462
+ if config.abi_compatible:
1463
+ self.writeline(
1464
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
1465
+ )
1466
+ else:
1467
+ self.writeline(f"{dst}.copy_({src});")
1468
+
1469
+ def codegen_multi_output(self, name, value):
1470
+ # in the abi_compatible mode, outputs are retrieved by passing
1471
+ # output pointers, so we skip its codegen here.
1472
+ if not config.abi_compatible:
1473
+ super().codegen_multi_output(name, value)
1474
+
1475
+ def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
1476
+ for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
1477
+ if config.abi_compatible:
1478
+ # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
1479
+ # input (outer_input) into another at::Tensor to be used as a subgraph input
1480
+ # (inner_input) in the nested scope. we can't std::move here, as the codegened
1481
+ # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
1482
+ # can't necessarily std::move it back to the origin (x).
1483
+ self.writeline(f"AtenTensorHandle {inner_input}_handle;")
1484
+ self.writeline(
1485
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
1486
+ )
1487
+ self.writeline(
1488
+ f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);"
1489
+ )
1490
+ else:
1491
+ self.writeline(
1492
+ f"{self.declare}{inner_input} = {outer_input}{self.ending}"
1493
+ )
1494
+
1495
+ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
1496
+ for inner_output, outer_output in zip(
1497
+ subgraph.graph.graph_outputs, outer_outputs
1498
+ ):
1499
+ src = inner_output.codegen_reference()
1500
+ if config.abi_compatible:
1501
+ # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
1502
+ # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
1503
+ # constructor is deleted.
1504
+ src = f"std::move({src})"
1505
+ self.writeline(f"{outer_output} = {src}{self.ending}")
1506
+
1507
+ def codegen_conditional(self, conditional):
1508
+ name = conditional.get_name()
1509
+ outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands]
1510
+ if config.abi_compatible:
1511
+ outer_outputs = []
1512
+ for out in conditional.outputs:
1513
+ # in ABI-compatible mode, ir.MultiOutput is not codegened,
1514
+ # hence pre-declare output variables directly and separately
1515
+ self.writeline(f"RAIIAtenTensorHandle {out.get_name()};")
1516
+ outer_outputs.append(out.get_name())
1517
+ predicate = f"{conditional.predicate.get_name()}_scalar"
1518
+ self.writeline(f"bool {predicate};")
1519
+ # in ABI-compatible mode, we need to use the ABI shim function
1520
+ # to extract a C++ bool from the unrelying scalar bool Tensor
1521
+ self.writeline(
1522
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));"
1523
+ )
1524
+ else:
1525
+ # in non-ABI-compatible mode, we can codegen the conditional outputs
1526
+ # as array of at::Tensor instances, as the ir.MultiOutput is codegened
1527
+ outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
1528
+ self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];")
1529
+ predicate = f"{conditional.predicate.codegen_reference()}.item<bool>()"
1530
+
1531
+ self.writeline(f"if ({predicate}) {{")
1532
+ self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
1533
+ self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
1534
+ self.writeline(ExitSubgraphLine(self))
1535
+ self.writeline("} else {")
1536
+ self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
1537
+ self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
1538
+ self.writeline(ExitSubgraphLine(self))
1539
+ self.writeline("}")
1540
+
1541
+ def generate_extern_kernel_args_decl_if_needed(
1542
+ self, op_overload, raw_args, output_args
1543
+ ):
1544
+ arg_types = [x.real_type for x in op_overload._schema.arguments]
1545
+ return_types = [x.type for x in op_overload._schema.returns]
1546
+
1547
+ new_tensor_args = []
1548
+ new_int_args = []
1549
+
1550
+ def fill_args(arg, arg_type):
1551
+ static_arg_types = (
1552
+ torch.FloatType,
1553
+ torch.BoolType,
1554
+ torch.StringType,
1555
+ torch.Type,
1556
+ torch.DeviceObjType,
1557
+ )
1558
+ inductor_tensor_buffers = (
1559
+ ir.Buffer,
1560
+ ir.ReinterpretView,
1561
+ )
1562
+
1563
+ if isinstance(arg_type, torch.TensorType):
1564
+ assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}"
1565
+ new_tensor_args.append(f"{arg.codegen_reference()}")
1566
+ elif isinstance(arg_type, torch.IntType):
1567
+ # int
1568
+ new_int_args.append(str(arg))
1569
+ elif isinstance(arg_type, torch.SymIntType):
1570
+ # SymInt
1571
+ expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
1572
+ new_int_args.append(self.expr_printer(expr))
1573
+ elif isinstance(arg_type, torch.NumberType):
1574
+ # Scalar of type int
1575
+ assert isinstance(arg, (int, float, bool))
1576
+ # Only treat int Scalar as dynamic
1577
+ if isinstance(arg, int):
1578
+ new_int_args.append(str(arg))
1579
+ elif isinstance(arg_type, torch.ListType):
1580
+ assert isinstance(arg, (list, tuple))
1581
+
1582
+ # List[Tensor]
1583
+ if isinstance(arg_type.getElementType(), torch.TensorType):
1584
+ new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg])
1585
+ # List[Optional[Tensor]]
1586
+ elif isinstance(
1587
+ arg_type.getElementType(), torch.OptionalType
1588
+ ) and isinstance(
1589
+ arg_type.getElementType().getElementType(), torch.TensorType
1590
+ ):
1591
+ new_tensor_args.extend(
1592
+ [f"{a.codegen_reference()}" for a in arg if a is not None]
1593
+ )
1594
+ # List[int]
1595
+ elif isinstance(arg_type.getElementType(), torch.IntType):
1596
+ new_int_args.extend([str(a) for a in arg])
1597
+ # List[SymInt]
1598
+ elif isinstance(arg_type.getElementType(), torch.SymIntType):
1599
+ expressions = [
1600
+ a.node.expr if isinstance(a, torch.SymInt) else a for a in arg
1601
+ ]
1602
+ new_int_args.extend(
1603
+ [self.expr_printer(expr) for expr in expressions]
1604
+ )
1605
+ # List[Scalar]
1606
+ elif isinstance(arg_type.getElementType(), torch.NumberType):
1607
+ # Only treat int Scalar as dynamic
1608
+ is_int_type = [isinstance(a, int) for a in arg]
1609
+ if any(is_int_type):
1610
+ assert all(
1611
+ is_int_type
1612
+ ), "AOTInductor only supports int scalars of the same type"
1613
+ new_int_args.extend([str(a) for a in arg])
1614
+ else:
1615
+ assert isinstance(
1616
+ arg_type.getElementType(), static_arg_types # type: ignore[arg-type]
1617
+ ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
1618
+ else:
1619
+ assert isinstance(
1620
+ arg_type, static_arg_types # type: ignore[arg-type]
1621
+ ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
1622
+
1623
+ for arg, arg_type in zip(raw_args, arg_types):
1624
+ if arg is not None:
1625
+ if isinstance(arg_type, torch.OptionalType):
1626
+ fill_args(arg, arg_type.getElementType())
1627
+ else:
1628
+ fill_args(arg, arg_type)
1629
+
1630
+ def fill_output_arg(arg, return_type):
1631
+ if isinstance(return_type, torch.TensorType):
1632
+ self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
1633
+ self.writeline(
1634
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
1635
+ )
1636
+ self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
1637
+ new_tensor_args.append(f"{arg}")
1638
+ elif isinstance(return_type, torch.SymIntType):
1639
+ raise NotImplementedError("NYI support for return type: SymInt")
1640
+ elif isinstance(return_type, torch.ListType) and isinstance(
1641
+ return_type.getElementType(), torch.SymIntType
1642
+ ):
1643
+ raise NotImplementedError("NYI support for return type: List[SymInt]")
1644
+ else:
1645
+ raise AssertionError(f"Unsupported return type found: {return_type}")
1646
+
1647
+ # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
1648
+ for return_type in return_types:
1649
+ if isinstance(return_type, (torch.TensorType)):
1650
+ pass
1651
+ elif isinstance(return_type, torch.OptionalType):
1652
+ assert isinstance(return_type.getElementType(), torch.TensorType)
1653
+ elif isinstance(return_type, torch.ListType):
1654
+ assert isinstance(return_type.getElementType(), torch.TensorType)
1655
+ else:
1656
+ raise NotImplementedError(
1657
+ f"return type {return_type} is not yet supported."
1658
+ )
1659
+
1660
+ for output_arg in output_args:
1661
+ assert output_arg is not None, "Optional return types are not yet supported"
1662
+ if isinstance(output_arg, (list, tuple)):
1663
+ for out in output_arg:
1664
+ fill_output_arg(out, torch.TensorType.get())
1665
+ else:
1666
+ fill_output_arg(output_arg, torch.TensorType.get())
1667
+
1668
+ return new_tensor_args, new_int_args
1669
+
1670
+ def generate_extern_kernel_alloc_and_find_schema_if_needed(
1671
+ self,
1672
+ name,
1673
+ kernel,
1674
+ codegen_args,
1675
+ cpp_op_schema,
1676
+ cpp_kernel_key,
1677
+ cpp_kernel_overload_name="",
1678
+ op_overload=None,
1679
+ raw_args=None,
1680
+ outputs=None,
1681
+ ):
1682
+ if config.is_fbcode():
1683
+ assert op_overload is not None
1684
+ assert raw_args is not None
1685
+ assert outputs is not None
1686
+
1687
+ return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
1688
+ name,
1689
+ cpp_kernel_key,
1690
+ op_overload,
1691
+ raw_args,
1692
+ outputs,
1693
+ )
1694
+ else:
1695
+ return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
1696
+ name,
1697
+ kernel,
1698
+ codegen_args,
1699
+ cpp_op_schema,
1700
+ cpp_kernel_key,
1701
+ cpp_kernel_overload_name,
1702
+ )
1703
+
1704
+ def generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
1705
+ self,
1706
+ name,
1707
+ kernel,
1708
+ codegen_args,
1709
+ cpp_op_schema,
1710
+ cpp_kernel_key,
1711
+ cpp_kernel_overload_name="",
1712
+ ):
1713
+ if cpp_kernel_key not in self.extern_call_ops:
1714
+ self.writeline(
1715
+ f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()"
1716
+ )
1717
+ self.writeline(
1718
+ f'\t.findSchemaOrThrow("{kernel}", "{cpp_kernel_overload_name}")'
1719
+ )
1720
+ self.writeline(f"\t.typed<{cpp_op_schema}>();")
1721
+ self.extern_call_ops.add(cpp_kernel_key)
1722
+
1723
+ self.writeline(
1724
+ f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
1725
+ )
1726
+
1727
+ def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
1728
+ self,
1729
+ name,
1730
+ cpp_kernel_key,
1731
+ op_overload,
1732
+ raw_args, # contains both args and flatten kwargs
1733
+ outputs,
1734
+ ):
1735
+ def extract_output_name(out):
1736
+ assert out is not None, "None, i.e. optional output is not supported"
1737
+ if isinstance(out, ir.MultiOutput):
1738
+ return out.get_name()
1739
+ elif isinstance(out, (list, tuple)):
1740
+ return type(out)(extract_output_name(o) for o in out)
1741
+ else:
1742
+ raise AssertionError(f"Unexpected output: {type(out)}")
1743
+
1744
+ # output_args has the same pytree structure as outputs
1745
+ output_args = extract_output_name(outputs)
1746
+ if isinstance(output_args, str):
1747
+ output_args = [output_args]
1748
+
1749
+ (
1750
+ tensor_call_args,
1751
+ int_call_args,
1752
+ ) = self.generate_extern_kernel_args_decl_if_needed(
1753
+ op_overload, raw_args, output_args
1754
+ )
1755
+
1756
+ tensor_call_args_str = ", ".join(tensor_call_args)
1757
+ int_call_args_str = ", ".join(int_call_args)
1758
+
1759
+ extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1
1760
+
1761
+ self.writeline(
1762
+ f"aoti_torch_proxy_executor_call_function(proxy_executor, "
1763
+ f"{extern_kernel_node_index}, "
1764
+ f"{len(int_call_args)}, "
1765
+ f"std::vector<int64_t>{{{int_call_args_str}}}.data(), "
1766
+ f"{len(tensor_call_args)}, "
1767
+ f"std::vector<AtenTensorHandle>{{{tensor_call_args_str}}}.data());"
1768
+ )
1769
+
1770
+ self.extern_call_ops.add(cpp_kernel_key)
1771
+
1772
+ def generate_reset_kernel_saved_flags(self):
1773
+ pass
1774
+
1775
+ def generate_save_uncompiled_kernels(self):
1776
+ pass
1777
+
1778
+ def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
1779
+ if (
1780
+ config.abi_compatible
1781
+ and not is_legacy_abi
1782
+ and isinstance(type_, torch.OptionalType)
1783
+ ):
1784
+ if val is None:
1785
+ return "0" # nullptr is not available in C
1786
+ if not isinstance(type_.getElementType(), torch.TensorType):
1787
+ var_name = f"var_{next(self.arg_var_id)}"
1788
+ self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
1789
+ return f"&{var_name}"
1790
+ elif config.c_shim_version == "2":
1791
+ # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
1792
+ base_handle = self.val_to_arg_str(val)
1793
+ if "wrap_with_raii_handle_if_needed" in base_handle:
1794
+ # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
1795
+ # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
1796
+ tmp_var_name = f"var_{next(self.arg_var_id)}"
1797
+ self.writeline(
1798
+ f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};"
1799
+ )
1800
+ base_handle = tmp_var_name
1801
+ var_name = f"var_{next(self.arg_var_id)}"
1802
+ self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();")
1803
+ return f"&{var_name}"
1804
+
1805
+ return self.val_to_arg_str(val)
1806
+
1807
+ def val_to_arg_str(self, val) -> str:
1808
+ if val is None:
1809
+ # When None is passed as an argument, it represents an optional that does not contain a value.
1810
+ if config.abi_compatible:
1811
+ return "0" # nullptr is not available in C
1812
+ return "c10::nullopt"
1813
+ elif isinstance(val, bool):
1814
+ if config.abi_compatible:
1815
+ return "1" if val else "0"
1816
+ else:
1817
+ return "true" if val else "false"
1818
+ elif isinstance(val, int):
1819
+ # uint64_t is long on Linux, but long long on MacOS
1820
+ return f"{val}LL" if sys.platform == "darwin" else f"{val}L"
1821
+ elif isinstance(val, str):
1822
+ return f'"{val}"'
1823
+ elif isinstance(
1824
+ val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox)
1825
+ ):
1826
+ return val.codegen_reference()
1827
+ elif isinstance(val, torch.device):
1828
+ return self.codegen_device(val)
1829
+ elif isinstance(val, torch.dtype):
1830
+ return self.codegen_dtype(val)
1831
+ elif isinstance(val, float) and val in [float("inf"), float("-inf")]:
1832
+ if val == float("inf"):
1833
+ return "std::numeric_limits<float>::infinity()"
1834
+ else:
1835
+ return "-std::numeric_limits<float>::infinity()"
1836
+ elif isinstance(val, (list, tuple)):
1837
+ # FIXME handle embedded optional types?
1838
+ result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}"
1839
+ if config.abi_compatible:
1840
+ static = self.is_statically_known_list_of_ints(val)
1841
+ # Need to pass the array length because we can't use std::vector
1842
+ int_var_array = self.codegen_int_array_var(
1843
+ result,
1844
+ known_statically=static,
1845
+ graph=self.get_codegened_graph(),
1846
+ )
1847
+ return f"{int_var_array}, {len(val)}"
1848
+ else:
1849
+ return result
1850
+ else:
1851
+ return repr(val)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ from itertools import chain, count
4
+ from typing import Any, List, Optional, TYPE_CHECKING
5
+
6
+ import sympy
7
+
8
+ from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
9
+
10
+ from .. import config
11
+ from ..codecache import CudaKernelParamCache
12
+ from ..triton_heuristics import grid as default_grid
13
+ from ..virtualized import V
14
+ from .cpp_wrapper_cpu import CppWrapperCpu
15
+ from .wrapper import SymbolicCallArg
16
+
17
+ if TYPE_CHECKING:
18
+ from ..graph import GraphLowering
19
+
20
+
21
+ def is_int(s: str) -> bool:
22
+ # Cpp code gen adds L at the end of ints
23
+ # Lets remove it for checking whether we have an int or not
24
+ if s and s[-1] == "L":
25
+ s = s[:-1]
26
+ try:
27
+ int(s)
28
+ except ValueError:
29
+ return False
30
+ except TypeError:
31
+ return False
32
+ return True
33
+
34
+
35
+ def is_float(s: str) -> bool:
36
+ try:
37
+ float(s)
38
+ except ValueError:
39
+ return False
40
+ return True
41
+
42
+
43
+ class CppWrapperCuda(CppWrapperCpu):
44
+ """
45
+ Generates cpp wrapper for running on GPU and calls CUDA kernels
46
+ """
47
+
48
+ def __init__(self):
49
+ self.device = "cuda"
50
+ super().__init__()
51
+ self.grid_id = count()
52
+ self.cuda = True
53
+
54
+ def write_header(self):
55
+ if V.graph.is_const_graph:
56
+ # We do not write header for constant graph, it will be written by main module.
57
+ return
58
+
59
+ super().write_header()
60
+
61
+ self.header.splice("#include <filesystem>")
62
+ if config.abi_compatible:
63
+ self.header.splice(
64
+ "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
65
+ )
66
+ else:
67
+ self.header.splice(
68
+ """
69
+ #include <c10/cuda/CUDAGuard.h>
70
+ #include <c10/cuda/CUDAStream.h>
71
+ #include <ATen/cuda/EmptyTensor.h>
72
+ """
73
+ )
74
+
75
+ self.header.splice(
76
+ """
77
+ #define CUDA_DRIVER_CHECK(EXPR) \\
78
+ do { \\
79
+ CUresult code = EXPR; \\
80
+ const char *msg; \\
81
+ cuGetErrorString(code, &msg); \\
82
+ if (code != CUDA_SUCCESS) { \\
83
+ throw std::runtime_error( \\
84
+ std::string("CUDA driver error: ") + \\
85
+ std::string(msg)); \\
86
+ } \\
87
+ } while (0);
88
+
89
+ namespace {
90
+
91
+ struct Grid {
92
+ Grid(uint32_t x, uint32_t y, uint32_t z)
93
+ : grid_x(x), grid_y(y), grid_z(z) {}
94
+ uint32_t grid_x;
95
+ uint32_t grid_y;
96
+ uint32_t grid_z;
97
+
98
+ bool is_non_zero() {
99
+ return grid_x > 0 && grid_y > 0 && grid_z > 0;
100
+ }
101
+ };
102
+
103
+ } // anonymous namespace
104
+
105
+ static inline CUfunction loadKernel(
106
+ std::string filePath,
107
+ const std::string &funcName,
108
+ uint32_t sharedMemBytes,
109
+ const std::optional<std::string> &cubinDir = std::nullopt) {
110
+ if (cubinDir) {
111
+ std::filesystem::path p1{*cubinDir};
112
+ std::filesystem::path p2{filePath};
113
+ filePath = (p1 / p2.filename()).string();
114
+ }
115
+
116
+ CUmodule mod;
117
+ CUfunction func;
118
+ CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
119
+ CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
120
+ if (sharedMemBytes > 0) {
121
+ CUDA_DRIVER_CHECK(cuFuncSetAttribute(
122
+ func,
123
+ CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
124
+ sharedMemBytes
125
+ ))
126
+ }
127
+ return func;
128
+ }
129
+
130
+ static inline void launchKernel(
131
+ CUfunction func,
132
+ uint32_t gridX,
133
+ uint32_t gridY,
134
+ uint32_t gridZ,
135
+ uint32_t numWarps,
136
+ uint32_t sharedMemBytes,
137
+ void* args[],
138
+ cudaStream_t stream) {
139
+ CUDA_DRIVER_CHECK(cuLaunchKernel(
140
+ func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
141
+ ));
142
+ }
143
+ """
144
+ )
145
+
146
+ def write_get_raw_stream(self, index, graph=None):
147
+ name = f"stream{index}"
148
+ self.writeline(f"cudaStream_t {name};")
149
+ self.writeline(
150
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
151
+ )
152
+ return name
153
+
154
+ def define_kernel(
155
+ self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
156
+ ):
157
+ if not cuda:
158
+ return super().define_kernel(name, kernel, metadata, cuda)
159
+
160
+ def generate(self, is_inference):
161
+ self.prefix.writeline("\n")
162
+ if not V.graph.aot_mode:
163
+ for kernel in chain(
164
+ self.src_to_kernel.values(),
165
+ [entry[0] for entry in self.user_defined_kernel_cache.values()],
166
+ ):
167
+ self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
168
+ self.prefix.writeline("\n")
169
+ return super().generate(is_inference)
170
+
171
+ @functools.lru_cache(None)
172
+ def generate_load_kernel_once(
173
+ self,
174
+ name: str,
175
+ mangled_name: str,
176
+ cubin_path: str,
177
+ shared_mem: int,
178
+ graph: "GraphLowering", # for per-graph caching
179
+ ):
180
+ if V.graph.aot_mode:
181
+ self.writeline(f"if (kernels.{name} == nullptr) {{")
182
+ self.writeline(
183
+ f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
184
+ )
185
+ self.writeline("}")
186
+ else:
187
+ self.writeline(f"if ({name} == nullptr) {{")
188
+ self.writeline(
189
+ f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
190
+ )
191
+ self.writeline("}")
192
+
193
+ def generate_args_decl(self, call_args):
194
+ dynamic_symbols = V.graph.sizevars.free_symbols()
195
+ # TODO: only works for constant now, need type info
196
+ new_args = []
197
+ for arg in call_args:
198
+ var_name = f"var_{next(self.arg_var_id)}"
199
+ if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
200
+ self.writeline(f"auto {var_name} = {arg};")
201
+ elif isinstance(arg, sympy.Expr):
202
+ self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
203
+ elif is_int(arg):
204
+ self.writeline(f"int {var_name} = {arg};")
205
+ elif is_float(arg):
206
+ self.writeline(f"float {var_name} = {arg};")
207
+ elif any(str(arg) == s.name for s in dynamic_symbols):
208
+ self.writeline(f"auto {var_name} = {arg};")
209
+ elif arg == "nullptr":
210
+ self.writeline(f"auto {var_name} = nullptr;")
211
+ elif arg == "c10::nullopt":
212
+ self.writeline(f"auto {var_name} = c10::nullopt;")
213
+ else:
214
+ if config.abi_compatible:
215
+ self.writeline(f"CUdeviceptr {var_name};")
216
+ self.writeline(
217
+ f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
218
+ )
219
+ else:
220
+ self.writeline(
221
+ f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
222
+ )
223
+ new_args.append(f"&{var_name}")
224
+
225
+ return ", ".join(new_args)
226
+
227
+ def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
228
+ """
229
+ Generate grid configs for launching a CUDA kernel using the grid
230
+ function from triton_heuristics.
231
+ """
232
+ if not cuda:
233
+ return grid
234
+ assert isinstance(grid, list), f"expected {grid=} to be a list"
235
+ grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
236
+ grid_fn = default_grid(*grid)
237
+ params = CudaKernelParamCache.get(name)
238
+ assert (
239
+ params is not None
240
+ ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
241
+ block_cfg = {
242
+ "XBLOCK": params["x_block"],
243
+ "YBLOCK": params["y_block"],
244
+ "ZBLOCK": params["z_block"],
245
+ }
246
+ return grid_fn(block_cfg)
247
+
248
+ def generate_kernel_call(
249
+ self,
250
+ name,
251
+ call_args,
252
+ grid=None,
253
+ device_index=None,
254
+ cuda=True,
255
+ triton=True,
256
+ arg_types=None,
257
+ grid_fn: str = "grid",
258
+ triton_meta=None,
259
+ ):
260
+ if not cuda:
261
+ # Even in CppWrapperCuda, we may see cpp kernels
262
+ return super().generate_kernel_call(
263
+ name, call_args, grid, device_index, cuda, triton, arg_types
264
+ )
265
+
266
+ params = CudaKernelParamCache.get(name)
267
+ assert (
268
+ params is not None
269
+ ), f"cuda kernel parameters for {name} should already exist at this moment"
270
+ mangled_name = params.get("mangled_name", None)
271
+ assert mangled_name is not None, "missing mangled_name"
272
+ cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
273
+ assert cubin_path is not None and os.path.exists(
274
+ cubin_path
275
+ ), f"cubin file should already exist at this moment: {cubin_path}"
276
+ shared_mem = params.get("shared_mem", 0)
277
+
278
+ self.generate_load_kernel_once(
279
+ name, mangled_name, cubin_path, shared_mem, V.graph
280
+ )
281
+
282
+ # args with value 1 are added into equal_to_1 and constants
283
+ # in triton_meta (in the Python codegen) which makes them
284
+ # inlined in the PTX and compiled CUBIN
285
+ if (
286
+ triton_meta is not None
287
+ and "configs" in triton_meta
288
+ and triton_meta["configs"]
289
+ ):
290
+ equal_to_1 = triton_meta["configs"][0].equal_to_1
291
+ call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1]
292
+
293
+ call_args = self.generate_args_decl(call_args)
294
+ kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
295
+ self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
296
+ stream = (
297
+ "stream"
298
+ if V.graph.aot_mode
299
+ else self.write_get_raw_stream(device_index, V.graph)
300
+ )
301
+ grid_name = f"{name}_grid_{next(self.grid_id)}"
302
+ assert isinstance(
303
+ grid, (list, tuple)
304
+ ), f"expected grid to be a list or tuple but got: {grid=}"
305
+
306
+ grid = [V.graph.sizevars.simplify(item) for item in grid]
307
+ grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
308
+ grid_args = [self.grid_expr_printer(item) for item in grid]
309
+ grid_args_str = ", ".join(grid_args)
310
+ self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
311
+
312
+ if grid_uses_symbolic_shapes:
313
+ self.writeline(f"if ({grid_name}.is_non_zero()) {{")
314
+ kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
315
+ self.writeline(
316
+ "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
317
+ kernel_var_name,
318
+ f"{grid_name}.grid_x",
319
+ f"{grid_name}.grid_y",
320
+ f"{grid_name}.grid_z",
321
+ params["num_warps"],
322
+ params["shared_mem"],
323
+ kernel_args_var,
324
+ stream,
325
+ )
326
+ )
327
+ if grid_uses_symbolic_shapes:
328
+ self.writeline("}")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
3
+
4
+ from ... import ir
5
+ from ...autotune_process import CUDABenchmarkRequest
6
+ from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox
7
+ from ...select_algorithm import ChoiceCaller
8
+ from ...utils import sympy_product
9
+ from ...virtualized import V
10
+
11
+ from ..common import IndentedBuffer, Kernel, OpOverrides, PrimitiveInfoType
12
+ from ..cpp import CppPrinter, DTYPE_TO_CPP
13
+
14
+ if TYPE_CHECKING:
15
+ from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
16
+
17
+ log = logging.getLogger(__name__)
18
+
19
+ cexpr = CppPrinter().doprint
20
+
21
+
22
+ def _normalize_idx(index: int, total_length: int) -> int:
23
+ return index if index >= 0 else index + total_length
24
+
25
+
26
+ class CUDAKernel(Kernel):
27
+ """
28
+ Baseclass for CUDA / Cutlass based Kernels
29
+ """
30
+
31
+ overrides = OpOverrides # type: ignore[assignment]
32
+
33
+
34
+ class CUDATemplateKernel(CUDAKernel):
35
+ """
36
+ Template kernels defined by CUDA / Cutlass in C++.
37
+ """
38
+
39
+ _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
40
+
41
+ def __init__(self, kernel_name):
42
+ """
43
+ Initializes a new instance of the CUDATemplateKernel class.
44
+
45
+ Args:
46
+ kernel_name (str): The name of the kernel.
47
+ """
48
+ super().__init__()
49
+ self.kernel_name = kernel_name
50
+ # Mapping from arg name to IRNode.
51
+ self.named_nodes: Dict[str, IRNode] = {}
52
+
53
+ def arg_name(self, node: IRNode) -> Optional[str]:
54
+ """
55
+ Returns arg name of a given input or output node.
56
+ """
57
+ if node is None:
58
+ return None
59
+ return {**self.args.input_buffers, **self.args.output_buffers}.get(
60
+ node.get_name(), None
61
+ )
62
+
63
+ def check_not_null(self, node: IRNode) -> str:
64
+ """
65
+ Generates code to check that a node is not null.
66
+ """
67
+
68
+ if node is None:
69
+ return ""
70
+
71
+ size_str = self.size(node, 0, -1)
72
+ name_str = self.arg_name(node)
73
+ if name_str is None:
74
+ return ""
75
+
76
+ res = IndentedBuffer(initial_indent=2)
77
+ res.tabwidth = 1
78
+ res.splice(
79
+ f"""
80
+ {{
81
+ if (!{name_str}) {{
82
+ int64_t {name_str}_size = {size_str};
83
+ if ({name_str}_size > 0) {{
84
+ throw std::runtime_error("input {name_str} is null but size is not 0!");
85
+ }}
86
+ }}
87
+ }}
88
+ """
89
+ )
90
+ return res.getvalue()
91
+
92
+ def def_kernel(
93
+ self,
94
+ inputs: List[IRNode],
95
+ outputs: List[IRNode],
96
+ names_str: str = "",
97
+ input_reorder: Optional[List[int]] = None,
98
+ ) -> str:
99
+ """
100
+ Hook called from template code to generate function definition and
101
+ needed args.
102
+
103
+ Args:
104
+ inputs: List of input IRNodes
105
+ outputs: List of output IRNodes
106
+ names_str: Comma separated list of input + output argument names.
107
+ input_reorder: The actual order of input nodes.
108
+ e.g. The template might have input argument defined as [X, W, Bias],
109
+ and the actual input passed into this template could be [Bias, X, W].
110
+ In this case, the `input_reorder` would be [2, 0, 1].
111
+ """
112
+
113
+ names = [x.strip() for x in names_str.strip().split(",")]
114
+ if len(inputs) + len(outputs) != len(names):
115
+ raise RuntimeError(
116
+ f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
117
+ )
118
+
119
+ if input_reorder is not None:
120
+ assert len(inputs) == len(input_reorder)
121
+ else:
122
+ input_reorder = list(range(len(inputs)))
123
+
124
+ for idx in input_reorder:
125
+ name = names[idx]
126
+ node = inputs[idx]
127
+ if node is not None:
128
+ self.named_nodes[name] = node
129
+ self.args.input_buffers[node.get_name()] = name
130
+
131
+ for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
132
+ if node is not None:
133
+ self.named_nodes[name] = node
134
+ self.args.output_buffers[node.get_name()] = name
135
+
136
+ arg_defs, *_ = self.args.cpp_argdefs()
137
+ return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})"
138
+
139
+ def call_kernel(
140
+ self, name: str, node: "CUDATemplateBuffer", epilogue_nodes: List[ir.Buffer] # type: ignore[name-defined]
141
+ ) -> None:
142
+ """
143
+ Generates code to call the kernel through V.graph.wrapper_code.
144
+ used from within torch._inductor.wrapper.WrapperCodeGen
145
+
146
+ name: Name of kernel function.
147
+ node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
148
+ as well as all required inputs and outputs.
149
+ """
150
+ wrapper = V.graph.wrapper_code
151
+ _, call_args, _ = self.args.python_argdefs()
152
+ # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
153
+ for i in range(len(call_args)):
154
+ if V.graph.is_unspec_arg(call_args[i]):
155
+ call_args[i] = call_args[i] + ".item()"
156
+ else:
157
+ call_args[i] = f"c_void_p({call_args[i]}.data_ptr())"
158
+
159
+ # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
160
+ # workspace_size should have already been retrieved prior to this call.
161
+ call_args.append("None")
162
+
163
+ if node.get_workspace_size() > 0:
164
+ call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())")
165
+ else:
166
+ call_args.append("None")
167
+
168
+ wrapper.generate_kernel_call(
169
+ name,
170
+ call_args,
171
+ device_index=V.graph.scheduler.current_device.index,
172
+ cuda=True,
173
+ triton=False,
174
+ )
175
+
176
+ def dtype(self, node: IRNode) -> Optional[str]:
177
+ """
178
+ Generates code which represents dtype of a given node.
179
+ """
180
+
181
+ if node is None:
182
+ return "void"
183
+ return DTYPE_TO_CPP.get(node.get_layout().dtype)
184
+
185
+ def offset(self, node: IRNode) -> str:
186
+ """
187
+ Generates code which represents offset of a given node.
188
+ """
189
+
190
+ if node is None:
191
+ return "0"
192
+ return str(node.get_layout().offset)
193
+
194
+ def ptr(self, node: IRNode) -> str:
195
+ """
196
+ Generates code which represents pointer of a given node.
197
+ """
198
+
199
+ if node is None:
200
+ return "nullptr"
201
+ arg_name = self.arg_name(node)
202
+ if arg_name is None:
203
+ return "nullptr"
204
+ offset = self.offset(node)
205
+ return arg_name if offset == "0" else f"{arg_name} + {offset}"
206
+
207
+ def size(
208
+ self,
209
+ node: IRNode,
210
+ start_index: int,
211
+ end_index: Optional[int] = None,
212
+ default_value: int = 0,
213
+ ) -> str:
214
+ """
215
+ Hook called from template code to get the size of an arg.
216
+ Generates code which represents size of a given node in [start_index, end_index).
217
+ If node is None, returns default_value.
218
+
219
+ TODO: Will add needed args to pass it in if it is dynamic.
220
+ """
221
+
222
+ if node is None:
223
+ return str(default_value)
224
+
225
+ start_index = _normalize_idx(start_index, len(node.get_size()))
226
+ if end_index is None:
227
+ end_index = start_index
228
+ end_index = _normalize_idx(end_index, len(node.get_size()))
229
+
230
+ sizes = node.get_size()[start_index : end_index + 1]
231
+ if len(sizes) == 0:
232
+ return str(default_value)
233
+
234
+ val = sympy_product(sizes)
235
+ return cexpr(self.rename_indexing(val))
236
+
237
+ def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
238
+ """
239
+ Hook called from template code to get the stride of an arg.
240
+ Generates code which represents stride of a given node at index.
241
+ If node is None, returns default_value.
242
+
243
+ TODO: Will add needed args to pass it in if it is dynamic.
244
+ """
245
+
246
+ if node is None:
247
+ return str(default_value)
248
+
249
+ index = _normalize_idx(index, len(node.get_size()))
250
+ if index < 0:
251
+ return str(default_value)
252
+
253
+ stride = node.get_stride()[index]
254
+ return cexpr(self.rename_indexing(stride))
255
+
256
+ def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
257
+ """
258
+ Hook called from template code to get the row or column stride of an arg.
259
+ This is required by some CUTLASS 2.X APIs.
260
+ If the node is in row_major, it returns stride[-2].
261
+ If the node is in column_major, it returns stride[-1].
262
+
263
+ TODO: Will add needed args to pass it in if it is dynamic.
264
+ """
265
+
266
+ if node is None or len(node.get_stride()) < 2:
267
+ return str(default_value)
268
+
269
+ stride0 = node.get_stride()[-1]
270
+ stride1 = node.get_stride()[-2]
271
+ if stride0 == 1:
272
+ return cexpr(self.rename_indexing(stride1))
273
+ elif stride1 == 1:
274
+ return cexpr(self.rename_indexing(stride0))
275
+ else:
276
+ raise RuntimeError(
277
+ f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
278
+ )
279
+
280
+
281
+ class CUDATemplateCaller(ChoiceCaller):
282
+ """
283
+ CUDATemplateCaller
284
+
285
+ This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
286
+ Attributes:
287
+ name (str): The name of the caller.
288
+ category (str): The category of the caller.
289
+ bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
290
+ template_buffer (CUDATemplateBuffer): The template buffer for the caller.
291
+ """
292
+
293
+ def __init__(
294
+ self,
295
+ name: str,
296
+ category: str,
297
+ input_nodes: List[Buffer],
298
+ layout: Layout,
299
+ make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str],
300
+ bmreq: CUDABenchmarkRequest,
301
+ template: "CUDATemplate", # type: ignore[name-defined]
302
+ info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg]
303
+ ):
304
+ super().__init__(name, input_nodes, layout)
305
+ self.category = category
306
+ self.make_kernel_render = make_kernel_render
307
+ self.bmreq = bmreq
308
+ self.template = template
309
+ self.info_kwargs = info_kwargs
310
+
311
+ def precompile(self) -> None:
312
+ assert self.bmreq is not None
313
+ self.bmreq.precompile()
314
+
315
+ def benchmark(self, *args, out) -> float:
316
+ assert self.bmreq is not None
317
+ return self.bmreq.benchmark(
318
+ *args, output_tensor=out
319
+ ) # @TODO: Hack for ensuring that Cutlass Kernel is preferred
320
+
321
+ def __str__(self):
322
+ return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
323
+
324
+ def call_name(self) -> str:
325
+ return f"cuda_template_kernels.{self.name}"
326
+
327
+ def hash_key(self) -> str:
328
+ return "-".join(
329
+ [
330
+ self.category,
331
+ self.bmreq.hash_key,
332
+ ]
333
+ )
334
+
335
+ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
336
+ """Information returned here is logged to the autotune log file when that is enabled."""
337
+ if self.info_kwargs is not None and "op" in self.info_kwargs:
338
+ op: Any = self.info_kwargs["op"]
339
+ epilogue_node_names: List[str] = [
340
+ getattr(en, "name", "no_name")
341
+ for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr]
342
+ ]
343
+ epilogue_node_strs: List[str] = [
344
+ str(en) for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr]
345
+ ]
346
+ return {
347
+ "backend": "CUDA",
348
+ "op_type": type(op).__name__,
349
+ "op_conf_name": str(op.configuration_name()),
350
+ "op_arch": str(op.arch),
351
+ "tile_shape": str(op.tile_description.tile_shape),
352
+ "epilogue_schedule": str(op.epilogue_schedule),
353
+ "kernel_schedule": str(op.kernel_schedule),
354
+ "element_accumulator": str(op.accumulator_type()),
355
+ "op_name": str(op.procedural_name()),
356
+ "epilogue_node_names": epilogue_node_names, # type: ignore[dict-item]
357
+ "epilogue_node_strs": epilogue_node_strs, # type: ignore[dict-item]
358
+ "instruction_shape": str(
359
+ op.tile_description.math_instruction.instruction_shape
360
+ ),
361
+ }
362
+ else:
363
+ return {"backend": "CUDA", "op_type": "unknown"}
364
+
365
+ def output_node(self) -> TensorBox:
366
+ return TensorBox.create(
367
+ CUDATemplateBuffer(
368
+ layout=self.layout,
369
+ inputs=self.input_nodes,
370
+ make_kernel_render=self.make_kernel_render,
371
+ workspace_size=self.bmreq.workspace_size,
372
+ template=self.template,
373
+ )
374
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (252 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import re
4
+ from typing import cast, Dict, List, Optional, Tuple
5
+
6
+ from ...config import cuda as inductor_cuda_config
7
+ from ...ir import Buffer, CUDATemplateBuffer, FixedLayout, IRNode, Layout
8
+ from ..common import IndentedBuffer
9
+
10
+ from . import cutlass_utils
11
+ from .cuda_kernel import CUDATemplateKernel
12
+ from .cuda_template import CUTLASSTemplate
13
+ from .cutlass_epilogue_gen import (
14
+ CutlassEVTEpilogueArgumentFormatter,
15
+ CutlassEVTEpilogueTypeFormatter,
16
+ )
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
# Jinja template for the full generated CUDA source of one CUTLASS GEMM kernel.
# `template` is the CUTLASSGemmTemplate instance, `kernel` the CUDATemplateKernel;
# the emitted entry point doubles as a workspace-size query when
# `workspace_size` is non-null.
GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}
{{template.globals().getvalue()}}
{{instance_definition}}
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} {
  try {
  {{kernel.check_not_null(X)}}
  {{kernel.check_not_null(W)}}
  {{kernel.check_not_null(Bias)}}
  {{kernel.check_not_null(Y)}}
  int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
  int64_t M = {{kernel.size(X, -2)}};
  int64_t K = {{kernel.size(X, -1)}};
  int64_t N = {{kernel.size(W, -1)}};
  using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
  using coord_t = cutlass::gemm::GemmCoord::Index;
  {{instance_type}}::Arguments arguments;
  {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw,
                                    X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}}
  {{instance_type}} gemm_op;
  if (workspace_size) {
    *workspace_size = gemm_op.get_workspace_size(arguments);
    return 0;
  }
  {
    auto status = gemm_op.can_implement(arguments);
    CUTLASS_CHECK(status);
  }
  {
    auto status = gemm_op.initialize(arguments, workspace, stream);
    CUTLASS_CHECK(status);
  }
  {
    auto status = gemm_op(stream);
    CUTLASS_CHECK(status);
  }
  }
  catch (std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    return -1;
  }
  return 0;
}
}
"""


# Argument-struct initializer for CUTLASS 2.x GemmUniversal kernels
# (single epilogue params struct, explicit batch/row strides).
GEMM_ARGS_CUTLASS_2X = r"""
  int64_t batch_stride_x = {{kernel.stride(X, -3)}};
  int64_t row_stride_x = {{kernel.row_or_column_stride(X)}};
  int64_t batch_stride_w = {{kernel.stride(W, -3)}};
  int64_t row_stride_w = {{kernel.row_or_column_stride(W)}};
  int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}};
  int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}};
  int64_t batch_stride_y = {{kernel.stride(Y, -3)}};
  int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}};
  // Initialize GemmUniversalInstance arguments.
  arguments = {
    {{template.gemm_mode()}},  // GemmUniversalMode mode
    {
      static_cast<coord_t>(M),
      static_cast<coord_t>(N),
      static_cast<coord_t>(K)
    },  // GemmCoord problem_size
    {{split_k if split_k > 1 else 'B'}},  // int batch_count
    {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})},  // typename EpilogueOutputOp::Params epilogue
    {{template.cutlass_type_cast(X, kernel.ptr(X))}},  // void const * ptr_A
    {{template.cutlass_type_cast(W, kernel.ptr(W))}},  // void const * ptr_B
    {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}},  // void const * ptr_C
    {{template.cutlass_type_cast(Y, kernel.ptr(Y))}},  // void * ptr_D
    batch_stride_x,  // int64_t batch_stride_A
    batch_stride_w,  // int64_t batch_stride_B
    batch_stride_bias,  // int64_t batch_stride_C
    batch_stride_y,  // int64_t batch_stride_D
    row_stride_x,  // typename LayoutA::Stride::LongIndex lda
    row_stride_w,  // typename LayoutB::Stride::LongIndex ldb
    row_stride_bias,  // typename LayoutC::Stride::LongIndex ldc
    row_stride_y,  // typename LayoutC::Stride::LongIndex ldd
  };
"""


# Argument-struct initializer for CUTLASS 3.x GemmUniversal kernels
# (cute strides; epilogue arguments are injected via {{epilogue_arguments}}).
GEMM_ARGS_CUTLASS_3X = r"""
  // Initialize GemmUniversal3xInstance arguments.
  arguments = {
    {{template.gemm_mode()}},  // GemmUniversalMode mode
    {
      static_cast<coord_t>({{M}}),
      static_cast<coord_t>({{N}}),
      static_cast<coord_t>(K),
      static_cast<coord_t>(B)
    },  // ProblemShape problem_shape
    {
      {{template.cutlass_type_cast(X, kernel.ptr(X))}},  // ElementA const* ptr_A
      {
        {{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
        {{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
        {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
      },  // StrideA dA
      {{template.cutlass_type_cast(W, kernel.ptr(W))}},  // ElementB const* ptr_B
      {
        {{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
        {{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
        {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
      },  // StrideB dB
    },  // MainloopArguments mainloop
    {{epilogue_arguments}}
  };
"""

# Epilogue sub-block of GEMM_ARGS_CUTLASS_3X; `epilogue_args` is either the
# plain alpha/beta params or an EVT argument string.
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
    // see https://tinyurl.com/4rk89z48
    {
      {{epilogue_args}},  // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
      {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}},  // ElementC const* ptr_C
      {
        {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}},
        {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}},
        {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}}
      },  // StrideC dC
      {{template.cutlass_type_cast(Y, kernel.ptr(Y))}},  // ElementD const* ptr_D
      {
        {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}},
        {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}},
        {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}}
      },  // StrideD dD
    },  // EpilogueArguments epilogue
"""
154
+
155
+
156
class CUTLASSGemmTemplate(CUTLASSTemplate):
    """
    CUTLASS GEMM template, which is used to generate CUTLASS GEMM kernels
    including those which allow flexible fusions with epilogues.
    """

    def __init__(
        self,
        input_nodes: List[Buffer],
        layout: Layout,
        alpha: float,
        beta: float,
        input_reorder: Optional[List[int]] = None,
        can_fuse_epilogue: Optional[bool] = None,
    ):
        """
        Args:
            input_nodes: input nodes of the kernel
            layout: layout of the output node
            alpha: alpha value of the GEMM operation
            beta: beta value of the GEMM operation
            input_reorder: reorder of the input nodes
            can_fuse_epilogue: If set to True, will only list and use operators capable of flexible epilogue fusions.
                               If False, it will not use those. If None, both may be listed, but it will not allow fusions.
                               Defaults to None
        """
        super().__init__("cutlass_gemm", input_nodes, layout, input_reorder)
        self.alpha = alpha
        self.beta = beta
        self.can_fuse_epilogue = can_fuse_epilogue

    @staticmethod
    def add_cutlass_gemm_choices(
        choices,
        layout,
        input_nodes,
        alpha=1,
        beta=0,
        input_reorder=None,
        fuseable=True,
        non_fuseable=True,
    ):
        """
        Appends autotuning choices for CUTLASS GEMM ops to `choices`, in place.

        Depending on `fuseable` / `non_fuseable`, up to two template instances
        are created: one listing non-EVT-fuseable ops and one listing ops that
        support EVT epilogue fusion.
        """
        if non_fuseable:
            if fuseable:
                # list both fuseable and non-fuseable ops, and treat them all as non-fuseable
                # NOTE(review): with can_fuse_epilogue=False, filter_op keeps only
                # non-EVT-capable ops here (EVT-capable ones are added by the
                # `fuseable` branch below) — the comment above reads broader than
                # what the code does; confirm intent against filter_op.
                can_fuse_epilogue = False
            else:
                can_fuse_epilogue = None

            cutlass_template = CUTLASSGemmTemplate(
                input_nodes,
                layout,
                alpha=alpha,
                beta=beta,
                input_reorder=input_reorder,
                can_fuse_epilogue=can_fuse_epilogue,
            )
            ops = cutlass_template.gen_ops()
            for op in ops:
                cutlass_template.maybe_append_choice(
                    choices,
                    op=op,
                )
        else:
            ops = []
        if fuseable:
            cutlass_template_evt = CUTLASSGemmTemplate(
                input_nodes,
                layout,
                alpha=alpha,
                beta=beta,
                input_reorder=input_reorder,
                can_fuse_epilogue=True,
            )
            # This will list only ops capable of EVT fusion
            ops_evt = cutlass_template_evt.gen_ops()
            for op in ops_evt:
                cutlass_template_evt.maybe_append_choice(
                    choices,
                    op=op,
                )
        else:
            ops_evt = []
        log.debug(
            "Added %d cutlass gemm configs and %d fuseable gemm configs.",
            len(ops),
            len(ops_evt),
        )

    def header(self) -> IndentedBuffer:
        """Common CUDA header plus the CUTLASS GEMM-specific #includes."""
        res = super().header()
        res.splice(
            """
                #include "cutlass/gemm/gemm.h"
                #include "cutlass/gemm/device/gemm_universal.h"
                #include "cutlass/gemm/device/gemm_universal_adapter.h"
                #include "cutlass/gemm/kernel/gemm_universal.hpp"
                #include "cutlass/gemm/collective/collective_builder.hpp"
                #include "cutlass/epilogue/collective/collective_builder.hpp"
                #include "cutlass/epilogue/collective/default_epilogue.hpp"
                #include "cutlass/epilogue/thread/linear_combination.h"
                #include "cutlass/gemm/dispatch_policy.hpp"
                #include "cutlass/gemm/kernel/tile_scheduler.hpp"
                #include "cutlass/util/distribution.h"
                #include "cutlass/util/packed_stride.hpp"
                #include "cutlass/util/tensor_view_io.h"
            """
        )
        return res

    @staticmethod
    def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]":  # type: ignore[name-defined]  # noqa: F821
        """Map a torch layout's last-two-dim strides to a CUTLASS layout,
        or None if it is neither row- nor column-major."""
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.library as cutlass_lib

        if torch_layout.stride[-1] == 1:
            return cutlass_lib.LayoutType.RowMajor
        elif torch_layout.stride[-2] == 1:
            return cutlass_lib.LayoutType.ColumnMajor
        else:
            return None

    @staticmethod
    def flip_cutlass_layout(
        cutlass_layout: "cutlass_lib.LayoutType",  # type: ignore[name-defined]  # noqa: F821
    ) -> "cutlass_lib.LayoutType":  # type: ignore[name-defined]  # noqa: F821
        """RowMajor <-> ColumnMajor."""
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.library as cutlass_lib

        if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
            return cutlass_lib.LayoutType.ColumnMajor
        else:
            return cutlass_lib.LayoutType.RowMajor

    @staticmethod
    def layout_match(torch_layout, cutlass_layout) -> bool:
        """True if the torch layout corresponds to the given CUTLASS layout."""
        return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout

    @staticmethod
    def set_alignment(torch_layout, op_element) -> bool:
        """If the tensor's achievable alignment meets the op element's minimum,
        record it on the op element (mutated in place) and return True."""
        alignment = cutlass_utils.get_max_alignment(torch_layout)
        if alignment < op_element.alignment:
            return False
        else:
            op_element.alignment = alignment
            return True

    @staticmethod
    def has_tma_epilogue(op) -> bool:
        """True for 3.x ops whose epilogue schedule name starts with 'tma'."""
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.library as cutlass_lib

        result = False
        if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
            epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1]
            result = epilogue_schedule_str.lower().startswith("tma")
        return result

    @staticmethod
    def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool:  # type: ignore[name-defined]  # noqa: F821
        """
        returns True if the op is capable of flexible epilogue fusions
        using epilogue visitor trees.

        See https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L283-L285  # noqa: B950
        """
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.library as cutlass_lib

        if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
            return False
        if op.epilogue_schedule not in (
            cutlass_lib.EpilogueScheduleType.TmaWarpSpecialized,
            cutlass_lib.EpilogueScheduleType.TmaWarpSpecializedCooperative,
        ):
            return False

        return True

    def render_evt_epilogue_declaration(
        self,
        template_output_node_name: str,
        evt_type_name: str,
        epilogue_nodes: List[IRNode],
    ) -> str:
        """Generates the epilogue for the EVT epilogue fusion"""
        return CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(
            template_output_node_name, evt_type_name, epilogue_nodes
        )

    def define_gemm_instance(
        self,
        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
        output_buffer_name: str,
        epilogue_nodes: Optional[List[IRNode]] = None,
    ) -> Tuple[str, str]:
        """
        Emit the C++ type definition for `op` and return
        (definition_source, instance_type_name).

        3.x ops are additionally wrapped in GemmUniversalAdapter; with epilogue
        nodes, the EVT-capable emitter is used and the epilogue functor is
        generated from the epilogue IR nodes.
        """
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.gemm_operation as cutlass_gemm_op
        import cutlass_library.library as cutlass_lib

        from torch._inductor.codegen.cuda.cutlass_lib_extensions.gemm_operation_extensions import (
            EmitGemmUniversal3xInstanceWithEVT,
        )

        if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
            if epilogue_nodes is not None and len(epilogue_nodes) > 0:
                emitter = EmitGemmUniversal3xInstanceWithEVT()
                # The emitter calls this back with the chosen functor type name.
                op.epilogue_functor = lambda epilogue_functor_type_name: self.render_evt_epilogue_declaration(
                    output_buffer_name, epilogue_functor_type_name, epilogue_nodes
                )
            else:
                emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
            op_def = emitter.emit(op)
            pattern = re.compile(r"\s*struct\s(.*?)\s:")
            # 3.x emitters declare the kernel as a struct; take the last one.
            decl = [line for line in op_def.split("\n") if "struct " in line][-1]
        else:
            if epilogue_nodes is not None and len(epilogue_nodes) > 0:
                raise RuntimeError(
                    "EVT epilogue fusion is not supported for Cutlass 2.x ops."
                )
            emitter = cutlass_gemm_op.EmitGemmInstance()
            op_def = emitter.emit(op)
            # Rewrite the emitted 2.x device type to the GemmUniversal API.
            op_def = op_def.replace(
                "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal"
            )
            op_def = op_def.replace("false,", "")
            pattern = re.compile(r"\s*using\s(.*?)\s=")
            decl = op_def.split("\n")[2]
        match = pattern.match(decl)
        if match is None:
            raise RuntimeError("Invalid Gemm config: \n" + op_def)
        op_type = match.groups()[0]
        if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
            op_def += f"\n  using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n"
            op_type = f"{op_type}_device_type"
        return op_def, op_type

    @staticmethod
    def should_swap_XW(
        bias: IRNode,
        beta: float,
    ) -> bool:
        """Whether to compute (W^T @ X^T)^T instead of X @ W; currently always
        swaps when a bias is present (see caller) pending the TODO below."""
        return True

        # TODO(ipiszy): Check whether it's necessary to swap X/W.
        # strides = bias.get_stride()
        # if strides[-1] != 1:
        #     return True
        # for stride in strides[:-1]:
        #     if stride != 0:
        #         return True
        # return False

    @staticmethod
    def swap_XW(
        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
    ) -> "cutlass_library.gemm_op.GemmOperation":  # type: ignore[name-defined]  # noqa: F821
        # Swap X and W in GemmOperation: transpose both operand layouts,
        # exchange A/B, and transpose the C/D (bias/output) layouts.
        new_op = copy.deepcopy(op)
        new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout)
        new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout)
        new_op.A, new_op.B = new_op.B, new_op.A
        new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout)
        new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout)
        return new_op

    def filter_op(
        self,
        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
    ) -> "Optional[cutlass_library.gemm_op.GemmOperation]":  # type: ignore[name-defined]  # noqa: F821
        """
        Check whether `op` is usable for this template's inputs/output.

        Returns None if the op is rejected (wrong kind, dtype, layout, or
        alignment), otherwise a deep copy of `op` with output/bias layouts,
        alignments, and epilogue element type filled in.
        """
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.library as cutlass_lib

        # Skip simt kernels
        if (
            op.tile_description.math_instruction.opcode_class
            == cutlass_lib.OpcodeClass.Simt
        ):
            return None

        # Only keep GemmUniversal kernels
        if op.gemm_kind not in {
            cutlass_lib.GemmKind.Universal,
            cutlass_lib.GemmKind.Universal3x,
        }:
            return None
        # Filter ops by dtypes.
        X = self.input_nodes[0]
        W = self.input_nodes[1]
        accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype(
            [X.get_dtype(), W.get_dtype()],
        )
        if not (
            cutlass_utils.dtype_match(X.get_dtype(), op.A.element)
            and cutlass_utils.dtype_match(W.get_dtype(), op.B.element)
            and cutlass_utils.dtype_match(
                self.output_node.get_layout().dtype, op.C.element
            )
            and cutlass_utils.dtype_match(
                accumulator_torch_dtype, op.accumulator_type()
            )
        ):
            return None

        # Filter ops by input layouts.
        if not (
            self.layout_match(X.get_layout(), op.A.layout)
            and self.layout_match(W.get_layout(), op.B.layout)
        ):
            return None

        # Update op.
        op = copy.deepcopy(op)

        # Set output layout.
        op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout())

        # Filter ops by alignments and set alignments.
        if not (
            self.set_alignment(X.get_layout(), op.A)
            and self.set_alignment(W.get_layout(), op.B)
            and self.set_alignment(self.output_node.get_layout(), op.D)
        ):
            return None

        # Set epilogue.
        # TODO: update epilogue functor according to epilogues.
        op.element_epilogue = op.accumulator_type()

        # Set bias layout and alignment.
        if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None:
            Bias = self.input_nodes[2]
            bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
            if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
                if bias_layout != op.D.layout:
                    # For cutlass2, bias and output layout must match
                    return None
            else:
                op.C.layout = bias_layout
            if not self.set_alignment(Bias.get_layout(), op.C):
                return None
        else:
            if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
                # No bias: 3.x ops can drop C entirely.
                op.C.element = cutlass_lib.DataType.void
            else:
                op.C.layout = op.D.layout
        supports_evt: bool = self.supports_evt(op)
        if (self.can_fuse_epilogue is not None) and (
            self.can_fuse_epilogue != supports_evt
        ):
            return None
        if inductor_cuda_config.cutlass_only_evt_capable_ops and not supports_evt:
            return None
        return op

    def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]":  # type: ignore[name-defined]  # noqa: F821
        """
        Enumerate CUTLASS GEMM ops, filter them through filter_op, dedupe by
        configuration name, and cap the result at
        cutlass_max_profiling_configs.
        """
        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.gemm_operation as cutlass_gemm_op
        import cutlass_library.library as cutlass_lib

        ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
        res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
        num_3x_ops = 0
        num_2x_ops = 0
        for op_dict in ops.values():
            for op_list in op_dict.values():
                for op in op_list:
                    assert isinstance(op, cutlass_gemm_op.GemmOperation)
                    filter_res = self.filter_op(op)
                    if (
                        filter_res is not None
                        and res.get(filter_res.configuration_name(), None) is None
                    ):
                        res[filter_res.configuration_name()] = filter_res
        for op in res.values():
            if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
                num_3x_ops += 1
            else:
                num_2x_ops += 1
        log.debug(
            "Got cutlass configs: total number of ops: %d, "
            "total number of 3x ops: %d, total number of 2x ops: %d",
            len(res),
            num_3x_ops,
            num_2x_ops,
        )
        return list(res.values())[: inductor_cuda_config.cutlass_max_profiling_configs]

    def gemm_mode(self) -> str:
        """kBatched when the output has a batch dimension, else kGemm."""
        sizes = self.output_node.get_size()
        if len(sizes) > 2:
            return "cutlass::gemm::GemmUniversalMode::kBatched"
        else:
            return "cutlass::gemm::GemmUniversalMode::kGemm"

    def render_gemm_arguments(
        self,
        argument_template: str,
        epilogue_template: Optional[str],
        should_swap_xw: bool,
        X: IRNode,
        W: IRNode,
        Bias: IRNode,
        Y: IRNode,
        alpha: float,
        beta: float,
        kernel: CUDATemplateKernel,
        epilogue_args,
    ) -> str:
        """
        Render the `arguments = {...}` initializer. For 3.x ops with a TMA
        epilogue, X/W (and Bias/Y strides) may be swapped/transposed so the
        bias ends up column major; M/N are then exchanged to match.
        """
        options = dict(
            alpha=self.alpha,
            beta=self.beta,
            X=X,
            W=W,
            Y=Y,
            Bias=Bias,
            template=self,
            kernel=kernel,
            M="M",
            N="N",
            epilogue_args=epilogue_args,
        )

        if epilogue_template is not None:
            if should_swap_xw:
                # Swap
                def clone_with_transposed_stride(node: IRNode) -> IRNode:
                    old_layout = node.get_layout()
                    new_stride = list(old_layout.stride)
                    new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2]
                    new_layout = FixedLayout(
                        old_layout.device,
                        old_layout.dtype,
                        list(old_layout.size),
                        new_stride,
                        old_layout.offset,
                    )
                    return Buffer(node.get_name(), new_layout)

                new_X = clone_with_transposed_stride(X)
                new_W = clone_with_transposed_stride(W)
                new_Bias = clone_with_transposed_stride(Bias)
                new_Y = clone_with_transposed_stride(Y)
                options["X"], options["W"], options["Bias"], options["Y"] = (
                    new_W,
                    new_X,
                    new_Bias,
                    new_Y,
                )
                options["M"], options["N"] = "N", "M"

            epilogue_arguments = self._template_from_string(epilogue_template).render(
                **options
            )
            arguments = self._template_from_string(argument_template).render(
                epilogue_arguments=epilogue_arguments, **options
            )
        else:
            arguments = self._template_from_string(GEMM_ARGS_CUTLASS_2X).render(
                split_k=1, **options
            )
        return arguments

    def render(  # type: ignore[override]
        self,
        kernel: CUDATemplateKernel,
        op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]  # noqa: F821
        template_buffer_node: Optional[CUDATemplateBuffer] = None,
        epilogue_nodes: Optional[List[IRNode]] = None,
        **kwargs,
    ) -> str:
        """
        Render the full CUDA source for `op`, optionally fusing
        `epilogue_nodes` via EVT (3.x TMA-epilogue ops only).
        """
        if epilogue_nodes is not None and len(epilogue_nodes) > 0:
            assert self.can_fuse_epilogue and CUTLASSGemmTemplate.supports_evt(
                op
            ), "op does not support EVT epilogue fusion"
            assert (
                template_buffer_node is not None
            ), "Template node is required for epilogue fusion"
            assert isinstance(
                template_buffer_node, CUDATemplateBuffer
            ), f"Template node has to be a CUDATemplateBuffer, is type {type(template_buffer_node)}"
            assert (
                template_buffer_node.name is not None
            ), "Output node has to be a Buffer with a name"
            # This is the name of the output of the Matmul, before epilogues are applied.
            # it is not necessarily materialized in global memory if we have an epilogue

        template_output_node_name = (
            template_buffer_node.name if template_buffer_node is not None else None
        )

        assert cutlass_utils.try_import_cutlass()
        import cutlass_library.gemm_operation as cutlass_gemm_op
        import cutlass_library.library as cutlass_lib

        assert isinstance(
            op, cutlass_gemm_op.GemmOperation
        ), "op argument is required and has to be an instance of GemmOperation"
        if template_buffer_node is not None:
            self.output_node = template_buffer_node
        if epilogue_nodes is not None and len(epilogue_nodes) > 0:
            # With a fused epilogue, the final epilogue node is what we emit into.
            self.output_node = cast(Buffer, epilogue_nodes[-1])

        assert len(self.input_nodes) >= 2 and self.output_node is not None
        X, W = self.input_nodes[0], self.input_nodes[1]
        Y = self.output_node
        Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]

        epilogue_template: Optional[str] = None
        should_swap_xw: bool = False
        # Default (non-EVT) epilogue params: {alpha, beta}.
        epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
        if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
            if Bias is not None and self.has_tma_epilogue(op):
                if self.should_swap_XW(Bias, self.beta):
                    # TMA epilogue requires bias vector in column major to get best perf.
                    op = self.swap_XW(op)
                    should_swap_xw = True
            if epilogue_nodes is not None and len(epilogue_nodes) > 0:
                epilogue_args = (
                    CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(
                        cast(str, template_output_node_name), epilogue_nodes
                    )
                )
            epilogue_template = GEMM_ARGS_CUTLASS_3X_EPILOGUE
            argument_template = GEMM_ARGS_CUTLASS_3X
        else:
            # TODO: Support split_k.
            argument_template = GEMM_ARGS_CUTLASS_2X

        instance_definition, instance_type = self.define_gemm_instance(
            op, cast(str, template_output_node_name), epilogue_nodes
        )
        options = dict(
            alpha=self.alpha,
            beta=self.beta,
            X=X,
            W=W,
            Y=Y,
            Bias=Bias,
            epilogue_template=epilogue_template,
            argument_template=argument_template,
            should_swap_xw=should_swap_xw,
            template=self,
            kernel=kernel,
            instance_definition=instance_definition,
            instance_type=instance_type,
            input_reorder=self.input_reorder,
            epilogue_args=epilogue_args,
        )
        res = self._template_from_string(GEMM_TEMPLATE).render(**options)
        return res
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import collections
4
+ import dataclasses
5
+ import itertools
6
+ import pprint
7
+ from typing import Any, Dict, Iterable, List, Optional, Protocol
8
+
9
+ import sympy
10
+
11
+ import torch
12
+ from .. import config, ir
13
+ from ..utils import cache_on_self, CachedMethod, IndentedBuffer
14
+ from ..virtualized import V
15
+
16
+ from .wrapper import (
17
+ AllocateLine,
18
+ FreeIfNotReusedLine,
19
+ MemoryPlanningLine,
20
+ NullLine,
21
+ ReuseLine,
22
+ )
23
+
24
+
25
+ ALIGN_BYTES = 64
26
+ assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"
27
+
28
+
29
+ def _align(nbytes):
30
+ """Round up to the nearest multiple of ALIGN_BYTES"""
31
+ return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
32
+
33
+
34
+ def _is_aligned(v: sympy.Expr):
35
+ """v can be statically proven to be a multiple of ALIGN_BYTES"""
36
+ if isinstance(v, (sympy.Add, sympy.Max)):
37
+ return all(map(_is_aligned, v.args))
38
+ return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
39
+
40
+
41
+ class align(sympy.Function):
42
+ """Symbolically round up to the nearest multiple of ALIGN_BYTES"""
43
+
44
+ nargs = (1,)
45
+ is_integer = True
46
+
47
+ @classmethod
48
+ def eval(cls, value):
49
+ if isinstance(value, (int, sympy.Integer)):
50
+ return _align(int(value))
51
+ if _is_aligned(value):
52
+ return value
53
+
54
+
55
@dataclasses.dataclass
class LiveRange:
    """
    A range where a given tensor is live. Begin and end are both counters
    representing points in the program of grouped memory operations.
    Begin is inclusive, end is exclusive.

    Invariant: begin <= end
    """

    begin: float  # int | ±inf
    end: float  # int | ±inf

    def contains(self, other: LiveRange):
        """Is other entirely within self"""
        if other.begin < self.begin:
            return False
        return other.end <= self.end

    def join(self, other: LiveRange):
        """Combine two ranges using a union operation"""
        lo = min(self.begin, other.begin)
        hi = max(self.end, other.end)
        return LiveRange(lo, hi)

    def __len__(self):
        return self.end - self.begin
78
+
79
+
80
class LiveRanges:
    """
    A collection of LiveRange regions, allowing for non-contiguous
    live regions.

    Invariant: LiveRanges.ranges is in sorted order and non-overlapping
    """

    def __init__(self, ranges: Iterable[LiveRange]):
        # Sort by start point, then greedily merge any range that touches or
        # overlaps its predecessor so the invariant above holds.
        ranges = [*sorted(ranges, key=lambda x: x.begin)]
        self.ranges = ranges[:1]
        for r in ranges[1:]:
            assert self.ranges[-1].begin <= r.begin
            if self.ranges[-1].end >= r.begin:
                self.ranges[-1] = LiveRange.join(self.ranges[-1], r)
            else:
                self.ranges.append(r)

    def overlaps(self, other: LiveRanges):
        """Check if any pair of ranges in self and other overlap"""
        # Linear merge-style sweep: always keep the deque with the earliest
        # front range in `left` (swapping is safe since overlap is symmetric),
        # test its front against the other side's front, then discard it.
        left = collections.deque(self.ranges)
        right = collections.deque(other.ranges)
        while left and right:
            if left[0].begin > right[0].begin:
                left, right = right, left
            assert left[0].begin <= right[0].begin
            if left[0].end > right[0].begin:
                return True
            left.popleft()
        return False

    @property
    def begin(self):
        # Earliest live point; raises IndexError when there are no ranges.
        return self.ranges[0].begin

    @property
    def end(self):
        # Latest live point; raises IndexError when there are no ranges.
        return self.ranges[-1].end

    def __repr__(self):
        return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])"
121
+
122
+
123
class AllocationTreeNode:
    """
    Abstract base class for nodes in allocation pool.
    """

    def allocate(self, block: Allocation, is_last: bool) -> bool:
        """
        Try to assign block to a memory location in this pool. Return True if
        an assignment was made.
        """
        # Base implementation never accepts an allocation; subclasses override.
        return False

    def get_live_ranges(self) -> LiveRanges:
        """Aggregate LiveRanges for all objects below this in tree"""
        raise NotImplementedError()

    def get_size_hint(self) -> int:
        """Number of bytes used for example inputs"""
        raise NotImplementedError()

    def get_symbolic_size(self) -> sympy.Expr:
        """Number of bytes needed at runtime"""
        raise NotImplementedError()

    def finalize(self, pool, offset) -> AllocationTreeNode:
        """Called after all allocations have been made"""
        # Default: node is already final; subclasses may return a replacement.
        return self

    def is_empty(self):
        # Only the Empty placeholder reports True.
        return False
153
+
154
+
155
@dataclasses.dataclass
class Allocation(AllocationTreeNode):
    """
    Represents memory allocated to a given node in the allocation pool.
    """

    node: ir.Buffer  # the inductor buffer this allocation backs
    live_range: LiveRange  # interval of the program where the buffer is live
    size_hint: int  # concrete byte size for example inputs
    symbolic_size: sympy.Expr  # byte size as an expression of runtime shapes
    allocated: bool = False  # set once by mark_allocated()
    pool: Optional[AllocationPool] = None  # filled in by finalize()
    offset: Optional[sympy.Expr] = None  # byte offset within pool, from finalize()

    @property
    def device(self):
        return self.node.get_device()

    def get_live_ranges(self):
        # A leaf allocation has exactly one contiguous live range.
        return LiveRanges([self.live_range])

    def get_size_hint(self):
        return self.size_hint

    def get_symbolic_size(self):
        return self.symbolic_size

    def mark_allocated(self):
        # Guard against assigning the same allocation to two pool slots.
        assert not self.allocated
        self.allocated = True

    def finalize(self, pool, offset):
        """Record the owning pool and byte offset; may only be called once."""
        assert self.pool is None and self.offset is None
        self.pool = pool
        self.offset = offset
        return self

    def codegen_alloc_from_pool(self, wrapper):
        """Emit wrapper code creating a tensor view of this slice of the pool."""
        assert self.pool
        node = self.node
        shape = tuple(node.get_size())
        stride = tuple(node.get_stride())
        return wrapper.codegen_alloc_from_pool(
            self.pool.name, self.offset, node.get_dtype(), shape, stride
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"node={self.node.get_name()}, "
            f"live_range={self.live_range}, "
            f"size_hint={self.size_hint}, "
            f"symbolic_size={self.symbolic_size}, "
            f"pool={self.pool.name if self.pool else None}, "
            f"offset={self.offset})"
        )
211
+
212
+
213
@dataclasses.dataclass
class Empty(AllocationTreeNode):
    """
    Placeholder to represent empty space in the allocation pool.
    Only exists to get the size_hint correct in parent nodes.
    """

    size_hint: int  # bytes of padding this placeholder occupies

    def get_size_hint(self):
        return self.size_hint

    def get_symbolic_size(self):
        # Empty space contributes nothing to the runtime size expression.
        return 0

    def get_live_ranges(self):
        # Never live: an empty collection of ranges.
        return LiveRanges([])

    def is_empty(self):
        return True
233
+
234
+
235
class MemorySplitProtocol(Protocol):
    """Structural interface shared by the split nodes (cached summary
    accessors plus the low-level _allocate() hook)."""

    get_live_ranges: CachedMethod[[], LiveRanges]
    get_size_hint: CachedMethod[[], int]
    get_symbolic_size: CachedMethod[[], sympy.Expr]

    def _allocate(self, block: Allocation, is_last: bool) -> bool:
        ...
242
+
243
+
244
class ClearCacheOnAllocateMixin(MemorySplitProtocol):
    """
    Helper to assist in caching get_live_ranges, get_size_hint, and
    get_symbolic_size.
    """

    def allocate(self, block: Allocation, is_last: bool):
        # A successful allocation mutates the tree, so the cached summary
        # values become stale and must be invalidated.
        is_allocated = self._allocate(block, is_last)
        if is_allocated:
            self.clear_cache()
        return is_allocated

    def clear_cache(self):
        # the cached methods expose a per-instance clear_cache(self) hook
        self.get_live_ranges.clear_cache(self)
        self.get_size_hint.clear_cache(self)
        self.get_symbolic_size.clear_cache(self)
260
+
261
+
262
@dataclasses.dataclass
class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
    """
    Contains a list of allocations not overlapping in LiveRanges.

    Invariant: no pair (a,b) in self.allocations will have:
        a.get_live_ranges().overlaps(b.get_live_ranges())
    """

    allocations: List[AllocationTreeNode]

    def _allocate(self, block: Allocation, is_last: bool):
        slot_size = self.get_size_hint()
        block_size = block.get_size_hint()
        if not is_last and block_size > slot_size:
            return False  # doesn't fit

        block_live = block.get_live_ranges()
        overlapping = [
            s for s in self.allocations if s.get_live_ranges().overlaps(block_live)
        ]
        if len(overlapping) > 1:
            # TODO(jansel): we could try harder here by merging overlapping in space
            return False
        elif len(overlapping) == 1:
            # recurse into the single child whose live range conflicts
            return overlapping[0].allocate(block, is_last)
        else:
            block.mark_allocated()

            if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty):
                # drop placeholder padding before inserting a real allocation
                self.allocations.pop()

            if slot_size == block_size:
                # perfect fit
                self.allocations.append(block)
            elif slot_size > block_size:
                # pad the new block so it matches the slot size
                self.allocations.append(
                    SpatialSplit.create(block, slot_size - block_size)
                )
            else:  # grow this allocation
                assert is_last
                # pad every existing child up to the new, larger slot size
                self.allocations = [
                    *(
                        SpatialSplit.create(a, block_size - slot_size)
                        for a in self.allocations
                    ),
                    block,
                ]
            return True

    @cache_on_self
    def get_live_ranges(self) -> LiveRanges:
        return LiveRanges(
            itertools.chain.from_iterable(
                x.get_live_ranges().ranges for x in self.allocations
            )
        )

    @cache_on_self
    def get_size_hint(self) -> int:
        # temporal siblings share the slot, so its size is the max child size
        if not self.allocations:
            return 0
        return max(x.get_size_hint() for x in self.allocations)

    @cache_on_self
    def get_symbolic_size(self) -> sympy.Expr:
        if not self.allocations:
            return 0  # type: ignore[return-value]
        return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])

    def is_empty(self):
        return len(self.allocations) == 1 and self.allocations[0].is_empty()

    def finalize(self, pool, offset):
        # all temporal siblings get the same offset — they never coexist
        self.allocations = [block.finalize(pool, offset) for block in self.allocations]
        self.clear_cache()
        if len(self.allocations) == 1:
            # collapse a degenerate split into its only child
            return self.allocations[0]
        return self
341
+
342
+
343
@dataclasses.dataclass
class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
    """
    Contains two allocations, left and right, that do not overlap in space.
    Right will be allocated immediately after left in memory.
    """

    left: TemporalSplit
    right: TemporalSplit

    @staticmethod
    def create(left, extra_space):
        """Wrap *left* and pad it with *extra_space* bytes of Empty on the right."""
        assert isinstance(left, AllocationTreeNode)
        assert isinstance(extra_space, int) and extra_space >= 1
        return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)]))

    def _allocate(self, block: Allocation, is_last: bool):
        # only the right side may grow, so is_last only propagates rightward
        return self.left.allocate(block, False) or self.right.allocate(block, is_last)

    @cache_on_self
    def get_live_ranges(self):
        return LiveRanges(
            itertools.chain(
                self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges
            )
        )

    @cache_on_self
    def get_size_hint(self) -> int:
        # left is rounded up to the alignment boundary before right begins
        return _align(self.left.get_size_hint()) + self.right.get_size_hint()

    @cache_on_self
    def get_symbolic_size(self) -> sympy.Expr:
        return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size()

    def finalize(self, pool, offset):
        self.left = self.left.finalize(pool, offset)
        # right starts at left's aligned end
        self.right = self.right.finalize(
            pool, offset + align(self.left.get_symbolic_size())
        )
        self.clear_cache()
        if self.right.is_empty():
            # padding never got used — collapse to just the left child
            return self.left
        return self
387
+
388
+
389
@dataclasses.dataclass
class AllocationPool:
    """
    Represents a pool of allocations that will be generated by a single
    call to torch.empty.
    """

    device: torch.device
    root: TemporalSplit
    can_expand: bool = True
    # when set, only blocks whose live range fits inside this one are accepted
    restrict_live_range: Optional[LiveRange] = None
    name: Optional[str] = None
    names_to_del: List[str] = dataclasses.field(default_factory=list)
    # maps alloc-from-pool expression -> existing variable name, to alias
    # identical allocations instead of re-emitting them
    creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict)

    def allocate(self, block: Allocation, is_last: bool):
        if self.restrict_live_range and not self.restrict_live_range.contains(
            block.live_range
        ):
            return False

        # growing the pool is only allowed on the last pool and if enabled
        is_last = self.can_expand and is_last
        if self.root.allocate(block, is_last):
            return True

        if is_last:
            return self.allocate_at_end(block)

        return False

    def allocate_at_end(self, block):
        """Grow the pool by appending *block* spatially after everything else."""
        block.mark_allocated()
        self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))])
        return True

    def finalize(self, name):
        """Assign the pool's variable name and propagate offsets down the tree."""
        assert not self.name
        self.name = name
        self.names_to_del.append(name)
        self.root.finalize(self, 0)

    def codegen_create(self, wrapper, code: IndentedBuffer):
        """Emit the torch.empty-style call that materializes this pool."""
        assert self.name
        nbytes = self.root.get_symbolic_size()
        for block in self.root.allocations:
            if isinstance(block, Allocation) and nbytes == block.get_symbolic_size():
                # optimization: fuse first allocation and pool creation
                node = block.node
                code.writeline(
                    wrapper.make_allocation(
                        self.name,
                        device=self.device,
                        dtype=node.get_dtype(),
                        shape=tuple(node.get_size()),
                        stride=tuple(node.get_stride()),
                    )
                )
                self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name
                return
        else:
            # NOTE: for/else — runs only when no single allocation spans the
            # whole pool; allocate a raw uint8 byte buffer instead.
            code.writeline(
                wrapper.make_allocation(
                    self.name,
                    device=self.device,
                    dtype=torch.uint8,
                    shape=(nbytes,),
                    stride=(1,),
                )
            )

    def codegen_destroy(self, wrapper, code: IndentedBuffer):
        """Emit the frees for the pool and every buffer name aliased into it."""
        code.writeline(wrapper.make_free_by_names(self.names_to_del))

    def __eq__(self, other):
        # identity semantics: pools are mutable and used as set members
        return self is other

    def __hash__(self):
        return id(self)
467
+
468
+
469
@dataclasses.dataclass
class AllocationPools:
    """
    Collection of many AllocationPool objects grouped by device.
    """

    device_to_pools: Dict[torch.device, List[AllocationPool]] = dataclasses.field(
        default_factory=dict
    )

    def get_pools(self, block):
        # lazily create the per-device pool list
        if block.device not in self.device_to_pools:
            self.device_to_pools[block.device] = []
        return self.device_to_pools[block.device]

    def allocate(self, block: Allocation):
        """Place *block* in the first pool that accepts it, or open a new pool."""
        pools = self.get_pools(block)

        for pool in pools:
            # only the final pool is allowed to expand to fit the block
            if pool.allocate(block, is_last=pool is pools[-1]):
                return

        # everything is full, make a new pool
        pools.append(
            AllocationPool(
                block.device,
                TemporalSplit([block]),
                can_expand=config.memory_pool != "none",
            )
        )
        block.mark_allocated()

    def allocate_output(self, block: Allocation):
        """Outputs get different pools so memory gets freed properly"""
        pools = self.get_pools(block)
        if pools and config.memory_pool in ("outputs", "combined"):
            pools[-1].allocate_at_end(block)
        else:
            # create a new pool
            block.mark_allocated()
            pools.append(
                AllocationPool(
                    block.device,
                    TemporalSplit([block]),
                    can_expand=config.memory_pool == "combined",
                )
            )

    def finalize(self):
        """Called at the end of allocation process"""
        for i, pool in enumerate(
            itertools.chain.from_iterable(self.device_to_pools.values())
        ):
            pool.finalize(f"pool{i}")

    def pprint(self):
        # debugging helper: dump every pool's name, live ranges, and tree
        for pool in itertools.chain.from_iterable(self.device_to_pools.values()):
            print()
            print(pool.name)
            print(pool.root.get_live_ranges())
            pprint.pprint(pool.root)
530
+
531
+
532
class BufferGroup:
    """
    Due to inplace reuse an allocated buffer can have many names.
    This tracks these collections of buffers sharing underlying memory.
    """

    def __init__(self, node: ir.Buffer):
        self.node = node
        # all names aliasing this storage (grows as ReuseLines are seen)
        self.names = [node.get_name()]
        self.is_output = False
        self.allocation: Optional[Allocation] = None
        # start with an inverted (empty) range; update_usage() expands it
        self.live_range = LiveRange(float("inf"), -float("inf"))

    def update_usage(self, timestep: int):
        """Expand self.live_range to include timestep"""
        self.live_range = LiveRange(
            min(timestep, self.live_range.begin),
            max(timestep, self.live_range.end),
        )

    def sym_nbytes(self):
        # symbolic byte size = element count (storage) * element size
        return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize

    def make_allocation(self):
        """Build the Allocation for this group; requires live ranges computed."""
        assert not self.allocation, "multiple allocations"
        assert isinstance(self.live_range.begin, int), "live ranges not computed"
        nbytes = self.sym_nbytes()
        # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have
        # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored.
        size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64)
        self.allocation = Allocation(
            self.node,
            self.live_range,
            size_hint=size_hint,
            symbolic_size=nbytes,
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, "
            f"live_range={self.live_range}"
        )
574
+
575
+
576
@dataclasses.dataclass
class PoolMemoryPlanningLine(MemoryPlanningLine):
    """Abstract base class for {Alloc,Dealloc}FromPoolLine"""

    group: BufferGroup
    # set by MemoryPlanner.compute_live_ranges()
    timestep: Optional[int] = None

    @property
    def node(self):
        return self.group.node
586
+
587
+
588
@dataclasses.dataclass
class AllocFromPoolLine(PoolMemoryPlanningLine):
    """Similar to AllocationLine, but takes memory from a pool"""

    # True for the first line that touches the pool: it must create the pool
    is_first_pool_usage: bool = False

    def codegen(self, code: IndentedBuffer):
        allocation = self.group.allocation
        assert allocation and allocation.pool
        pool = allocation.pool
        name = self.node.get_name()

        if self.is_first_pool_usage:
            pool.codegen_create(self.wrapper, code)

        # every alias of this storage must be freed with the pool
        pool.names_to_del.extend(self.group.names)
        alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper)
        if alloc_from_pool in pool.creation_cache:
            # an identical view was already emitted — alias it instead
            code.writeline(
                self.wrapper.make_tensor_alias(
                    name, pool.creation_cache[alloc_from_pool], "alloc"
                )
            )
        else:
            pool.creation_cache[alloc_from_pool] = name
            code.writeline(
                f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}"
            )
616
+
617
+
618
@dataclasses.dataclass
class DeallocFromPoolLine(PoolMemoryPlanningLine):
    """Similar to FreeIfNotReusedLine, but takes memory from a pool"""

    # True for the last line that touches the pool: it must destroy the pool
    is_last_pool_usage: bool = False

    def codegen(self, code: IndentedBuffer):
        if self.is_last_pool_usage:
            assert self.group.allocation and self.group.allocation.pool
            self.group.allocation.pool.codegen_destroy(self.wrapper, code)
628
+
629
+
630
@dataclasses.dataclass
class MemoryPlanner:
    """
    Coordination object to run memory planning passes during wrapper
    codegen.
    """

    wrapper: Any
    pools: AllocationPools = dataclasses.field(default_factory=AllocationPools)
    buffer_groups: Optional[List[BufferGroup]] = None

    def plan(self, lines: List[Any]) -> List[Any]:
        """Call all the memory planning passes in sequence"""
        lines = [*lines]  # copy so the caller's list is not mutated
        self.drop_removed_buffers(lines)
        self.convert_to_pool_lines(lines)
        self.compute_live_ranges(lines)
        self.allocate_groups()
        self.mark_first_last_usage(lines)
        return lines

    def drop_removed_buffers(self, lines):
        """
        Replace any memory planning lines in V.graph.removed_buffers with NullLine
        """
        # drop any removed buffers
        for i, line in enumerate(lines):
            if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)):
                if line.node.get_name() in V.graph.removed_buffers:
                    lines[i] = NullLine(self.wrapper)

    def compute_buffer_groups(self, lines):
        """
        Populates self.buffer_groups with BufferGroup objects that join
        allocations with common storage (due to inplace reuse) into a
        single object.
        """
        name_to_group = {}
        for line in lines:
            if isinstance(line, AllocateLine):
                name = line.node.get_name()
                assert name not in name_to_group
                name_to_group[name] = BufferGroup(line.node)
            elif isinstance(line, ReuseLine):
                old_name = line.node.get_name()
                new_name = line.reused_as.get_name()
                assert new_name not in name_to_group
                # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc
                if old_name in name_to_group:
                    # alias: both names now map to the same group object
                    name_to_group[old_name].names.append(new_name)
                    name_to_group[new_name] = name_to_group[old_name]

        outputs = set(V.graph.get_output_names())
        # deduplicate by object identity (aliased names share one group)
        unique_groups = [*{id(g): g for g in name_to_group.values()}.values()]
        for group in unique_groups:
            group.is_output = any(x in outputs for x in group.names)

        assert self.buffer_groups is None
        self.buffer_groups = unique_groups
        return name_to_group

    def convert_to_pool_lines(self, lines):
        """
        Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their
        pool-based counterparts.
        """
        name_to_group = self.compute_buffer_groups(lines)
        for i, line in enumerate(lines):
            if isinstance(line, AllocateLine):
                if line.node.get_name() in name_to_group:
                    lines[i] = AllocFromPoolLine(
                        self.wrapper, name_to_group[line.node.get_name()]
                    )
            elif isinstance(line, FreeIfNotReusedLine):
                assert not line.is_reused
                if line.node.get_name() in name_to_group:
                    lines[i] = DeallocFromPoolLine(
                        self.wrapper, name_to_group[line.node.get_name()]
                    )
            elif isinstance(line, ReuseLine):
                if line.node.get_name() in name_to_group:
                    # the group's pool dealloc handles freeing, not ReuseLine
                    line.delete_old = False

    def compute_live_ranges(self, lines):
        """Populate every BufferGroup.live_ranges field based on first/last usage"""
        timestep = 0
        worklist = collections.deque(lines)
        while worklist:
            if isinstance(worklist[0], MemoryPlanningLine):
                # consecutive planning lines share one timestep
                timestep += 1
                while worklist and isinstance(worklist[0], MemoryPlanningLine):
                    line = worklist.popleft()
                    if isinstance(line, PoolMemoryPlanningLine):
                        line.group.update_usage(timestep)
                        line.timestep = timestep
            else:
                worklist.popleft()

        # outputs must stay live past the end of the program
        timestep += 1
        assert self.buffer_groups is not None
        for group in self.buffer_groups:
            if group.is_output:
                group.update_usage(timestep)

    def allocate_groups(self):
        """
        Assign every allocation to a specific location in a specific AllocationPool.
        """
        assert config.memory_pool in ("none", "intermediates", "outputs", "combined")
        assert self.buffer_groups is not None

        for group in self.buffer_groups:
            group.make_allocation()

        outputs: List[Allocation] = []
        intermediates: List[Allocation] = []
        for group in self.buffer_groups:
            assert group.allocation
            if group.is_output and config.memory_pool != "combined":
                outputs.append(group.allocation)
            else:
                intermediates.append(group.allocation)

        # outputs: smallest size first, longer live range first on ties
        for block in sorted(
            outputs,
            key=lambda x: (
                x.size_hint,
                -len(x.live_range),
            ),
        ):
            self.pools.allocate_output(block)

        # intermediates: largest size first, longer live range first on ties
        for block in sorted(
            intermediates,
            key=lambda x: (
                -x.size_hint,
                -len(x.live_range),
            ),
        ):
            self.pools.allocate(block)

        self.pools.finalize()

    def mark_first_last_usage(self, lines):
        """
        Populate the AllocFromPoolLine.is_first_pool_usage and
        DeallocFromPoolLine.is_last_pool_usage fields so that pools
        are created/destroyed.
        """
        seen = set()
        for line in lines:
            if isinstance(line, AllocFromPoolLine):
                assert line.group.allocation
                pool = line.group.allocation.pool
                assert pool is not None
                if pool not in seen:
                    line.is_first_pool_usage = True
                    seen.add(pool)

        seen = set()
        # scan backwards so the first dealloc we meet per pool is its last use
        for line in reversed(lines):
            if isinstance(line, DeallocFromPoolLine):
                assert line.group.allocation
                pool = line.group.allocation.pool
                assert pool is not None
                if pool not in seen:
                    # only destroy here if nothing in the pool outlives this line
                    line.is_last_pool_usage = (
                        pool.root.get_live_ranges().end <= line.timestep
                    )
                    seen.add(pool)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Any, List
4
+
5
+ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
6
+
7
+ from .. import config
8
+ from ..codecache import PyCodeCache, TritonFuture
9
+ from ..utils import cache_on_self, do_bench
10
+ from ..virtualized import V
11
+ from .common import TensorArg
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
def get_kernel_argdefs(kernel):
    """Return only the argument definitions from the kernel's python_argdefs()."""
    defs, _call_args, _arg_types = kernel.args.python_argdefs()
    return defs
19
+
20
+
21
+ def _get_all_args(args_list):
22
+ all_args = max(args_list, key=len)[:]
23
+ for args in args_list:
24
+ assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}"
25
+
26
+ return all_args
27
+
28
+
29
def get_all_kernel_argdefs(kernels):
    """
    The logic here must match with `get_all_call_args`.
    """
    return _get_all_args([get_kernel_argdefs(k) for k in kernels])
36
+
37
+
38
def get_all_call_args(call_args_list):
    """
    Passed in the call_args for each subkernel and return the call_args for the
    combined multi-kernel.

    Note an algorithm as follows does not always work:
    ```
        all_call_args: Dict[
            Any, None
        ] = {}  # use a dict rather than set to maintain insertion order
        for call_args in call_args_list:
            all_call_args.update({arg: None for arg in call_args})

        all_call_args = list(all_call_args.keys())
    ```
    It will fail if any kernel has the same argument passed in multiple times.
    Check test_pass_same_arg_multi_times in test_multi_kernel.py

    Instead, we pick the longest call args and assert that other call args are
    a subset of it.
    """
    return _get_all_args(call_args_list)
60
+
61
+
62
def get_numel_argdefs(kernel):
    """Return "<prefix>numel" argdefs, skipping the reduction numel outside reductions."""
    return [
        f"{tree.prefix}numel"
        for tree in kernel.range_trees
        if kernel.inside_reduction or tree.prefix != "r"
    ]
69
+
70
+
71
class MultiKernelState:
    """
    Maintain state of multi-kernel compilation so we don't define duplicated
    multi-kernel for the same set of sub-kernels.

    V.graph.wrapper_code has a reference to MultiKernelState instance.
    """

    def __init__(self):
        # maps tuple-of-subkernel-names -> generated multi-kernel name
        self.subkernel_to_kernel_name = {}

    def define_kernel(self, kernels):
        """
        Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}".
        This has some minor issue.

        E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca ,
        there are 2 flavors of non-persistent reduction:
          https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4
        and
          https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd

        The only difference is cache eviction policy.

        We should name the multi-kernel differently in these 2 cases.
        """
        kernel_names = tuple(k.kernel_name for k in kernels)
        if kernel_names in self.subkernel_to_kernel_name:
            return self.subkernel_to_kernel_name[kernel_names]

        # name the multi kernel based on the first kernel
        multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
        self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name

        if V.graph.cpp_wrapper:
            # we should not generate any python code for multi-kernel during
            # the second pass of cpp-wrapper.
            return multi_kernel_name

        wrapper = V.graph.wrapper_code

        # one `call{idx}` closure per subkernel in the generated `run` body
        kernel_call_def_code = "\n".join(
            [
                f"""
    def call{idx}(need_clone_args=False):
        args = [{', '.join(get_kernel_argdefs(kernels[idx]))}]
        if need_clone_args:
            args, _ = multi_kernel_call.kernels[{idx}].clone_args(*args)
        multi_kernel_call.kernels[{idx}].run(*args, {', '.join(get_numel_argdefs(kernels[idx]))}, grid=grid, stream=stream)
        """.format(
                    idx
                ).strip(
                    "\n"
                )
                for idx in range(len(kernels))
            ]
        )

        # add subkernel src code hashes to the multi-kernel source code so changing a
        # subkernel implementation will result in a different py file for
        # multi-kernel. This makes cache implementation straightforward since
        # we can decide cache file name based on multi-kernel py file name
        # directly.
        #
        # Without the hash added for subkernels, the cache file may be shared by
        # different subkernels which is incorrect.
        subkernel_hashes = "\n".join(
            f"# subkernel{i} code hash: {kernel.code_hash}"
            for i, kernel in enumerate(kernels)
        )

        src_code = f"""
{subkernel_hashes}
def run(multi_kernel_call, {', '.join(get_all_kernel_argdefs(kernels))}, {', '.join(get_numel_argdefs(kernels[0]))}, grid, stream):
{kernel_call_def_code}
    multi_kernel_call.run_with_argless_kernels([call0, call1])
"""  # noqa: B950 line too long
        wrapper.header.splice(
            f"""
{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [
    {", ".join(kernel_names)},
],
'''
"""
        )
        wrapper.header.splice(src_code)
        wrapper.header.splice(
            """
'''
)
"""
        )

        return multi_kernel_name
165
+
166
+
167
class MultiKernel:
    """
    This class maintains the compile time state for multi kernels.

    Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
    The generated definition for the multi-kernel will looks like:
    ```
    multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
    ```

    Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
    """

    def __init__(self, kernels):
        assert len(kernels) >= 2

        self.kernels = kernels
        self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
            kernels
        )

        # need this since some code in inductor check if the kernel object has an args
        # attribute to decide if it's a non-null kernel.
        self.args = object()

    def call_kernel(self, kernel_name):
        """
        Collect the union of arguments from all subkernels as the arguments
        for the multi-kernel.
        """
        assert kernel_name == self.kernel_name
        call_args_list = [kernel.get_call_args() for kernel in self.kernels]

        all_call_args = get_all_call_args(call_args_list)
        grid: List[Any] = []

        if V.graph.cpp_wrapper:
            # for the second pass of cpp-wrapper codegen, we should call
            # the fast kernel directly
            picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
            kernel_name = self.kernels[picked_kernel].kernel_name
            final_call_args = call_args_list[picked_kernel]
        else:
            final_call_args = all_call_args

        # numels for all subkernels should be the same. Use kernels[0] here
        self.kernels[0].add_numel_to_call_args_and_grid(
            kernel_name, final_call_args, grid
        )

        grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)

        V.graph.wrapper_code.generate_kernel_call(
            kernel_name,
            final_call_args,
            grid,
            V.graph.scheduler.current_device.index,
        )

    def codegen_nan_check(self):
        """Emit nan/inf asserts for every tensor argument of every subkernel (deduplicated)."""
        wrapper = V.graph.wrapper_code
        seen = set()
        for k in self.kernels:
            _, call_args, arg_types = k.args.python_argdefs()
            for arg, arg_type in zip(call_args, arg_types):
                if arg in seen:
                    continue
                seen.add(arg)
                if isinstance(arg_type, TensorArg):
                    line = f"assert not {arg}.isnan().any().item()"
                    wrapper.writeline(line)
                    line = f"assert not {arg}.isinf().any().item()"
                    wrapper.writeline(line)

    @property
    def removed_buffers(self):
        # a buffer is only removable if every subkernel removed it
        return set.intersection(*[k.removed_buffers for k in self.kernels])

    @property
    def inplaced_to_remove(self):
        return set.intersection(*[k.inplaced_to_remove for k in self.kernels])

    @property
    @cache_on_self
    def inplace_update_buffers(self):
        """
        Make sure all kernels have the same inplace update mappings.
        """
        for k in self.kernels[1:]:
            assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers
        return self.kernels[0].inplace_update_buffers

    def warn_mix_layout(self, kernel_name: str):
        # intentionally a no-op for multi-kernels
        pass
261
+
262
+
263
class MultiKernelCall:
    """
    This class is called at run time to actually run the kernel
    """

    def __init__(self, multi_kernel_name, kernels, src_code):
        assert len(kernels) >= 2
        self._kernels = kernels
        self.multi_kernel_name = multi_kernel_name

        self._run = PyCodeCache.load(src_code).run
        # caching is disabled when the perf metric table is on, so every run
        # re-benchmarks and records fresh numbers
        self.disable_cache = os.environ.get(
            "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE"
        ) == "1" or is_metric_table_enabled("persistent_red_perf")

        self.picked_kernel = None
        if config.triton.multi_kernel > 1:
            # manually force a subkernel to ease perf testing
            picked_by_config = config.triton.multi_kernel - 2
            assert picked_by_config < len(self._kernels)
            self.picked_kernel = picked_by_config
        elif not self.disable_cache:
            self.load_cache()

        # ensures record_choice() runs at most once per call object
        self._recorded = False

    def cache_file_path(self):
        # cache file lives next to the generated multi-kernel .py file
        py_file_path = self._run.__globals__["__file__"]
        return os.path.splitext(py_file_path)[0] + ".picked_kernel"

    def load_cache(self):
        """Load a previously benchmarked kernel choice from disk, if present."""
        assert self.picked_kernel is None
        path = self.cache_file_path()
        if os.path.exists(path):
            with open(path) as fd:
                self.picked_kernel = int(fd.read())
            assert self.picked_kernel >= 0 and self.picked_kernel < len(
                self._kernels
            )
            log.debug(
                "Load picked kernel %d from cache file %s", self.picked_kernel, path
            )

    def store_cache(self):
        """Persist the benchmarked kernel choice to disk."""
        assert self.picked_kernel is not None
        path = self.cache_file_path()
        with open(path, "w") as fd:
            fd.write(str(self.picked_kernel))
        log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path)

    @property
    def kernels(self):
        """
        Read results from future.

        This should be called after parallel compilation is done.
        In case you call this before compilation is done,
        it may slow down the parallel compilation.
        """
        for i, kernel in enumerate(self._kernels):
            if isinstance(kernel, TritonFuture):
                # resolve the future in place so we only wait once
                self._kernels[i] = kernel.result()

        return self._kernels

    def run(self, *args, **kwargs):
        # dispatch into the generated `run` function (see MultiKernelState)
        self._run(self, *args, **kwargs)

    @staticmethod
    def benchmark_sub_kernels(kernel_calls):
        """
        Benchmark all the sub kernels and return the execution time
        (in milliseconds) for each of them.

        Unit test may mock this method to force a specific kernel to
        be picked.
        """
        return [
            do_bench(lambda: kernel_call(True), rep=40, fast_flush=True)
            for kernel_call in kernel_calls
        ]

    # record_choice and lookup_choice are helper functions for cpp-wrapper
    # codegen. The first pass use record_choice to keep the choice and
    # the second pass do lookup by calling lookup_choice.
    #
    # An alternative that reused the multi-kernel cache does not work well
    # since during codegen of the second pass, it's very hard to know the
    # path for the cache file. Also reading the cache file need do some IO
    # which can be slower.
    @staticmethod
    def record_choice(multi_kernel_name, choice):
        """
        Record the multi-kernel choice for cpp-wrapper first pass codegen
        for the second pass.

        We should do nothing if this function is not called during codegen.
        """
        from torch._inductor.graph import GraphLowering

        if not isinstance(V.graph, GraphLowering):
            return

        if not V.graph.record_multi_kernel_choice:
            return

        V.graph.multi_kernel_to_choice[multi_kernel_name] = choice

    @staticmethod
    def lookup_choice(multi_kernel_name):
        # this should always been done during cpp-wrapper codegen
        assert V.graph.record_multi_kernel_choice
        # there should be no miss
        return V.graph.multi_kernel_to_choice[multi_kernel_name]

    def run_with_argless_kernels(self, kernel_calls):
        """Benchmark-once dispatch: pick the fastest subkernel, then always run it."""
        if self.picked_kernel is None:
            timings = self.benchmark_sub_kernels(kernel_calls)
            self.picked_kernel = timings.index(min(timings))
            k0 = self.kernels[0]
            log.debug(
                "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s",
                self.picked_kernel,
                [k.inductor_meta.get("kernel_name") for k in self.kernels],
                k0.size_hints,
                k0.inductor_meta.get("reduction_hint"),
                timings,
            )

            def get_kernel_path(k):
                return k.fn.fn.__code__.co_filename

            # NOTE(review): the metric row assumes exactly two subkernels
            # (indexes 0 and 1) even though >2 are allowed elsewhere.
            get_metric_table("persistent_red_perf").add_row(
                lambda: {
                    "kernel1_name": get_kernel_path(self.kernels[0]),
                    "kernel2_name": get_kernel_path(self.kernels[1]),
                    "kernel1_latency": timings[0],
                    "kernel2_latency": timings[1],
                    "size_hints": k0.size_hints,
                    "reduction_hint": k0.inductor_meta.get("reduction_hint"),
                    "speedup": timings[1] / timings[0],
                }
            )

            if not self.disable_cache:
                self.store_cache()

        if not self._recorded:
            self._recorded = True
            self.record_choice(self.multi_kernel_name, self.picked_kernel)
        kernel_calls[self.picked_kernel]()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_foreach.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Tuple
5
+
6
+ from sympy import Integer
7
+
8
+ import torch
9
+
10
+ from .. import metrics
11
+ from ..scheduler import SchedulerNode
12
+ from ..utils import ceildiv, Placeholder
13
+ from ..virtualized import V
14
+ from .common import IndentedBuffer, Kernel
15
+ from .triton import gen_common_triton_imports, TritonKernel
16
+ from .triton_utils import config_of, signature_to_meta
17
+
18
+
19
@dataclass
class PartitionState:
    """Mutable accumulator used while horizontally partitioning foreach nodes.

    ``partitions`` holds the completed groups, ``cur_partition`` the group
    currently being filled, and ``cur_count`` the running read/write argument
    count of the current group.
    """

    partitions: List[
        List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]]
    ]
    cur_partition: List[
        Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]
    ]
    cur_count: int

    def finalize(self):
        """Flush the in-progress partition into ``partitions`` if it is non-empty."""
        if not self.cur_partition:
            return
        self.partitions.append(self.cur_partition)
32
+
33
+
34
class ForeachKernel(Kernel):
    """Combined triton kernel dispatching over several fused sub-kernels.

    Each sub-kernel handles one horizontally-fused slice of a foreach op; the
    generated kernel body is an ``if/elif`` chain on the x program id that
    selects the matching sub-kernel.
    """

    MAX_NUM_ARGS = 250  # number where I would no longer get triton errors

    @staticmethod
    def _update_partition(partition_state, node_rw_count, node_info):
        # Start a new partition when adding this node would exceed the
        # argument-count limit; otherwise grow the current one.
        if partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS:
            partition_state.partitions.append(partition_state.cur_partition)
            partition_state.cur_partition = [node_info]
            partition_state.cur_count = node_rw_count
        else:
            partition_state.cur_count += node_rw_count
            partition_state.cur_partition.append(node_info)

    @staticmethod
    def horizontal_partition(subkernel_nodes, triton_scheduling):
        """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
        for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
        (read/writes) and to have the same 2D or 1D blocking strategy."""
        assert len(subkernel_nodes) >= 1

        # 1D-tiled nodes share one partition stream; 2D-tiled nodes are
        # additionally bucketed by their y extent so a partition tiles uniformly.
        partition_state_1d = PartitionState([], [], 0)
        yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )

        for node in subkernel_nodes:
            fused_nodes = node.get_nodes()
            # prefer a reduction node's group when present
            _, (numel, rnumel) = max(
                fused_nodes, key=lambda x: int(x.is_reduction())
            ).group
            tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel)
            node_info = fused_nodes, tiled_groups, numel, rnumel

            read_writes = node.read_writes
            read_write_count = len(read_writes.reads) + len(read_writes.writes)

            if tiled_groups[1] == 1:
                # 1D blocking
                ForeachKernel._update_partition(
                    partition_state_1d, read_write_count, node_info
                )
            else:
                # 2D blocking, grouped by y extent
                y_elem = tiled_groups[0]
                partition_state_2d = yelem_to_partition_state_2d[y_elem]
                ForeachKernel._update_partition(
                    partition_state_2d, read_write_count, node_info
                )

        partition_state_1d.finalize()
        all_partitions = partition_state_1d.partitions
        for partition_state_2d in yelem_to_partition_state_2d.values():
            partition_state_2d.finalize()
            all_partitions.extend(partition_state_2d.partitions)

        return all_partitions

    def __init__(self):
        super().__init__()
        self.blocking_2d = False
        self.block_size_1d = 1024  # Try tuning this value
        self.block_size_2d = 32
        self.num_warps = 8
        self.sub_kernels = []
        # shared counters/state so sub-kernels generate non-colliding names
        self.iter_vars_count = itertools.count()
        self.x_block_count = 0
        self.y_block_count = 0

    def get_block_size(self):
        # x block size depends on whether we tile in 1D or 2D
        if self.blocking_2d:
            return self.block_size_2d
        else:
            return self.block_size_1d

    @staticmethod
    def codegen_pid_offsets(code, block_count, lower_bound, prefix):
        # Rebase the program id so each sub-kernel sees ids starting at 0.
        if block_count == 0:
            code.splice(f"{prefix}pid_offset = {prefix}pid")
        else:
            code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}")

    def codegen_pid_range(self, code, x_elems):
        """Emit the ``if``/``elif`` guard selecting this sub-kernel by x pid."""
        num_x_blocks = ceildiv(x_elems, self.get_block_size())
        upper_bound_x_pid = self.x_block_count + num_x_blocks
        lower_bound_x_pid = self.x_block_count

        # first sub-kernel opens the chain with `if`, later ones use `elif`
        if self.x_block_count == 0:
            cond = "if"
        else:
            cond = "elif"

        x_pid_bounds_check = (
            f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}"
        )
        code.splice(f"{cond} {x_pid_bounds_check}:")

        with code.indent():
            ForeachKernel.codegen_pid_offsets(
                code, num_x_blocks, lower_bound_x_pid, "x"
            )
        self.x_block_count += num_x_blocks

    def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):
        """Build a TritonKernel for one sub-op, sharing args/counters with self."""
        sub_kernel = TritonKernel(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache={
                "tl.program_id(0)": "xpid_offset",
                "tl.program_id(1)": "ypid",
            },
            reduction_hint=reduction_hint,
        )
        if self.blocking_2d:
            assert len(groups) == 3

        self.blocking_2d |= groups[1] != 1 and len(groups) == 3
        # the sub-kernel is never emitted standalone; undo its counter bump
        metrics.generated_kernel_count -= 1
        sub_kernel.args = self.args
        sub_kernel.iter_vars_count = self.iter_vars_count
        sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
        self.sub_kernels.append(sub_kernel)
        return sub_kernel

    def jit_lines(self):
        """Build the ``@triton_heuristics.foreach`` / ``@triton.jit`` decorators."""
        # 32-bit indexing is only safe if every sub-kernel can use it
        can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
        size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
        _, _, signature = self.args.python_argdefs()
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=size_dtype),
            "device": V.graph.scheduler.current_device.index,
            "device_type": V.graph.scheduler.current_device.type,
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            "backend_hash": torch.utils._triton.triton_hash_with_backend(),
        }
        return f"""
            @triton_heuristics.foreach(
                num_warps={self.num_warps},
                triton_meta={triton_meta!r},
                inductor_meta={inductor_meta!r},
            )
            @triton.jit
        """

    def grid(self):
        # all x blocks; y blocks only under 2D tiling
        return (
            self.x_block_count,
            ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
            if self.blocking_2d
            else 1,
            1,
        )

    def codegen_kernel(self, name=None):
        """Emit the full triton source for the combined foreach kernel."""
        code = IndentedBuffer()

        code.splice(gen_common_triton_imports())
        argdefs, _, _ = self.args.python_argdefs()
        code.splice(self.jit_lines())
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )

        with code.indent():
            code.splice("xpid = tl.program_id(0)")
            if self.blocking_2d:
                code.splice("ypid = tl.program_id(1)")
                code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
                code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
            else:
                code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")

            for sub_kernel in self.sub_kernels:
                assert len(sub_kernel.numels) <= 3
                # TODO mlazos: support dynamic shapes
                numel_ind = 0 if not self.blocking_2d else 1
                self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))
                with code.indent():
                    if self.blocking_2d:
                        code.splice(f"ynumel = {sub_kernel.numels[0]}")
                        code.splice(f"xnumel = {sub_kernel.numels[1]}")
                    else:
                        code.splice(f"xnumel = {sub_kernel.numels[0]}")

                    sub_kernel.codegen_body()
                    code.splice(sub_kernel.body)

            # trailing else keeps the emitted if/elif chain syntactically closed
            code.splice("else:")
            with code.indent():
                code.splice("pass")

        return code.getvalue()

    def call_kernel(self, code, name: str):
        """Emit the host-side call launching this kernel."""
        _, call_args, _ = self.args.python_argdefs()
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
        if V.graph.cpp_wrapper:
            V.graph.wrapper_code.generate_kernel_call(
                name,
                call_args,
                device_index=V.graph.scheduler.current_device.index,
                grid=self.grid(),
            )
        else:
            # TODO: refactor generate_kernel_call
            call_args_str = ", ".join(call_args)
            stream_name = code.write_get_raw_stream(
                V.graph.scheduler.current_device.index
            )
            code.writeline(
                f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})"
            )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ from typing import Optional, Set
4
+
5
+ from torch._inductor import config, ir
6
+
7
+ from torch._inductor.codegen.triton import (
8
+ IterationRangesRoot,
9
+ triton_compute_type,
10
+ TritonKernel,
11
+ TritonKernelOverrides,
12
+ )
13
+
14
+ from torch._prims_common import prod
15
+
16
+ from torch.utils._sympy.functions import CeilDiv
17
+
18
+
19
class TritonSplitScanKernel(TritonKernel):
    """Generates a triton kernel that supports ops.scan calls while also splitting
    the reduction dimension over multiple triton programs.

    For this kernel, loop numels will always take the form ``(xdim, rdim)``
    and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication
    between blocks occurs within a global memory workspace buffer, which
    must be zero-filled before launching the kernel.

    Note that generation for ``ops.reduction`` is not supported.

    For details of the communication strategy, see
    https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    """

    def __init__(
        self,
        *groups,
        index_dtype: str,
        mutations: Optional[Set[str]] = None,
        reduction_hint=ir.ReductionHint.DEFAULT,
        min_elem_per_thread=0,
    ):
        super().__init__(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache=None,
            reduction_hint=reduction_hint,
            min_elem_per_thread=min_elem_per_thread,
        )
        # the x dimension gets no tensor axis of its own in this kernel
        self.no_x_dim = True

    def initialize_range_tree(self, pid_cache):
        """Set up iteration-range roots; only y/x/r prefixes are supported."""
        prefixes = "yxr"
        assert len(self.numels) <= len(
            prefixes
        ), "z dimension not supported for split scan"
        # take the last len(numels) prefixes so r is always present
        active_prefixes = prefixes[len(prefixes) - len(self.numels) :]

        # grid dim order is reversed relative to the prefix order
        grid_dims = "rxy"
        for numel, prefix in zip(self.numels, active_prefixes):
            is_reduction = prefix == "r"
            tensor_dim = 0 if is_reduction else None
            grid_dim = grid_dims.find(prefix)
            self.range_trees.append(
                IterationRangesRoot(
                    f"{prefix}index",
                    numel,
                    prefix,
                    grid_dim,
                    self,
                    pid_cache=pid_cache,
                    is_loop=False,
                    tensor_dim=tensor_dim,
                    grid_dim=grid_dim,
                )
            )
        for tree in self.range_trees:
            tree.codegen_header(self.body)

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # ops.reduction is intentionally unsupported here (see class docstring)
        raise NotImplementedError("NYI TritonSplitDimKernel reductions")

    def scan(self, dtype, combine_fn, value, init):
        """Emit a split scan: block-local scan plus decoupled-lookback carry."""
        import triton.language as tl

        compute_type = triton_compute_type(dtype)
        compute_type_triton = getattr(tl, compute_type[3:])

        element_nbits = compute_type_triton.primitive_bitwidth

        # scratch slot packs {value, flag}; 64-bit elements need 3 slots
        scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64"
        scratch_type_triton = getattr(tl, scratch_type[3:])
        scratch_elems_per_block = 3 if element_nbits == 64 else 1
        scratch_nbytes_per_block = scratch_elems_per_block * (
            scratch_type_triton.primitive_bitwidth // 8
        )

        cse_load = functools.partial(self.cse.generate, self.loads)
        cse_compute = functools.partial(self.cse.generate, self.compute)

        assert len(self.numels) == 2, "Unexpected tiling"
        # worst-case number of blocks sizes the zero-filled workspace
        min_rblock = config.triton.min_split_scan_rblock
        max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock)
        nbytes = scratch_nbytes_per_block * max_blocks
        scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
        if offset != 0:
            scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")
        runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})")
        scratch_base = cse_load(
            f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * "
            f"{scratch_elems_per_block} * {runtime_rblocks}"
        )

        masks = {f"{tree.prefix}mask" for tree in self.range_trees}
        self.filter_masks(masks)
        masks = sorted(masks)
        if self._load_mask:
            masks.append(self._load_mask)

        # masked-out lanes contribute the identity element `init`
        value = cse_compute(f"{value}.to({compute_type})")
        value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})")
        init = cse_compute(f"tl.full([], {init}, {compute_type})")
        if masks:
            cond = " & ".join(masks)
            masked_value = cse_compute(TritonKernelOverrides.where(cond, value, init))
        else:
            masked_value = value

        combine_helper_fn = self._lift_helper(combine_fn, 2)
        dim = self.triton_tensor_ndim() - 1
        assert dim == 0, ""

        # per-block total, fed into the cross-block lookback
        block_sum = cse_compute(
            f"tl.reduce({masked_value}, {dim}, {combine_helper_fn})"
        )
        exclusive_prefix = self.cse.newvar()
        if element_nbits == 64:
            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64(
                    {scratch_base},
                    {block_sum},
                    {self.range_trees[-1].get_pid()},
                    {combine_helper_fn},
                    {init},
                )
                """,
                strip=True,
            )

        else:
            assert element_nbits <= 32
            value_as_uint_dtype = f"tl.uint{element_nbits}"

            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback(
                    {scratch_base},
                    {block_sum},
                    {self.range_trees[-1].get_pid()},
                    {combine_helper_fn},
                    {init},
                    DTYPE_VALUE_AS_UINT={value_as_uint_dtype},
                    DTYPE_PACK={scratch_type},
                )
                """,
                strip=True,
            )
        # Compute final cumsum
        block_scan = cse_compute(
            f"tl.associative_scan({masked_value}, {dim}, {combine_helper_fn})"
        )
        return cse_compute(f"{combine_helper_fn}({exclusive_prefix}, {block_scan})")

    def _get_heuristic(self):
        # selects the @triton_heuristics decorator variant
        return "split_scan"

    def _get_grid_fn(self):
        # grid helper matching the (CeilDiv(rdim, RBLOCK), xdim) layout
        return "split_scan_grid"
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import torch
4
+
5
+ from .. import config
6
+ from ..utils import _type_of, instance_descriptor
7
+ from ..virtualized import V
8
+ from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg
9
+
10
+
11
def signature_of(arg: KernelArgType, *, size_dtype: str) -> str:
    """Return the triton signature string for a single kernel argument.

    Pointers are prefixed with ``*``; unspec (0d-tensor-as-scalar) args are
    demoted to scalar types, and integer size args follow ``size_dtype``.
    """
    if isinstance(arg, TensorArg):
        # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes.
        # Related PR: https://github.com/openai/triton/pull/2279/
        fp8_types = {
            torch.float8_e4m3fn: "*fp8e4nv",
            torch.float8_e5m2: "*fp8e5",
            torch.float8_e4m3fnuz: "*fp8e4b8",
            torch.float8_e5m2fnuz: "*fp8e5b16",
        }
        tye = fp8_types.get(arg.dtype)
        if tye is None:
            tye = _type_of(arg.dtype)
        if not V.graph.is_unspec_arg(arg.buffer):
            return tye
        # had unwrapped 0d tensor as scalar
        new_tye = tye.lstrip("*")
        if new_tye in ["fp16", "bf16"]:
            return "fp32"
        return new_tye
    if isinstance(arg, SizeArg):
        if arg.expr is None:
            # From triton/runtime/jit.py
            # `None` is nullptr. Implicitly convert to *i8.
            return "*i8"
        if isinstance(arg.expr, float):
            return "fp32"
        if size_dtype == "tl.int32":
            return "i32"
        if size_dtype == "tl.int64":
            return "i64"
        raise NotImplementedError(f"unhandled size_dtype {size_dtype}")
    if isinstance(arg, WorkspaceArg):
        return "*i8"
    raise NotImplementedError(f"unhandled {type(arg)}: {arg}")
50
+
51
+
52
def signature_to_meta(
    signature: List[KernelArgType],
    *,
    size_dtype: str,
    indices: Optional[List[int]] = None,
) -> Dict[int, str]:
    """Map each argument's signature position to its triton type string.

    ``indices`` overrides the default positions 0..len(signature)-1.
    """
    if indices is None:
        indices = list(range(len(signature)))
    meta: Dict[int, str] = {}
    for position, arg in zip(indices, signature):
        meta[position] = signature_of(arg, size_dtype=size_dtype)
    return meta
64
+
65
+
66
def config_of(
    args: List[KernelArgType],
    *,
    indices: Optional[List[int]] = None,
) -> Any:
    """Build the triton ``instance_descriptor`` (alignment / folded-arg hints)
    for a kernel signature.

    ``indices`` gives each arg's position in the full signature; defaults to
    ``0..len(args)-1``.
    """
    if indices is None:
        indices = list(range(len(args)))

    def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
        """
        Roughly follow triton code here:
        https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
        """
        if isinstance(x, TensorArg):
            if include_tensor:
                # Pointer args: the byte offset must be provably a multiple of
                # `alignment` and the underlying buffer must not be unaligned.
                offset_aligned = V.graph.sizevars.statically_known_multiple_of(
                    x.offset * x.dtype.itemsize, alignment  # type: ignore[arg-type]
                )
                return offset_aligned and not V.graph.scheduler.is_unaligned_buffer(
                    x.buffer
                )
            else:
                return False
        if isinstance(x, SizeArg):
            # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with
            # _maybe_evaluate_static...
            if x.name.startswith("load_seed_offset"):
                return False
            if x.expr is None:
                return False
            if isinstance(x.expr, float):
                return False
            return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)  # type: ignore[arg-type]
        if isinstance(x, WorkspaceArg):
            return V.graph.sizevars.statically_known_multiple_of(x.nbytes, alignment)  # type: ignore[arg-type]
        raise NotImplementedError(f"unhandled {type(x)}: {x}")

    if config.triton.divisible_by_16:
        divisible_by_16 = tuple(
            i
            for i, arg in zip(indices, args)
            if is_aligned(arg, alignment=16, include_tensor=True)
        )
    else:
        divisible_by_16 = ()
    # divisible-by-8 hint excludes tensor (pointer) args by construction
    divisible_by_8 = tuple(
        i
        for i, arg in zip(indices, args)
        if is_aligned(arg, alignment=8, include_tensor=False)
    )

    # size args statically provable to equal 1 can be constant-folded
    equal_to_1 = tuple(
        i
        for i, arg in zip(indices, args)
        if isinstance(arg, SizeArg)
        and arg.expr is not None
        and V.graph.sizevars.statically_known_equals(arg.expr, 1)  # type: ignore[arg-type]
    )
    # ids_of_folded_args is set from equal_to_1
    # and None args by the Triton compiler
    ids_of_folded_args = tuple(equal_to_1)

    return instance_descriptor(
        divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8
    )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py ADDED
@@ -0,0 +1,1543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import contextlib
3
+ import dataclasses
4
+ import functools
5
+ import inspect
6
+ import operator
7
+ import re
8
+ from itertools import count
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ Iterator,
14
+ List,
15
+ Optional,
16
+ Set,
17
+ Tuple,
18
+ TYPE_CHECKING,
19
+ Union,
20
+ )
21
+
22
+ import sympy
23
+ from sympy import Expr
24
+
25
+ import torch
26
+ import torch._ops
27
+ from torch._dynamo.utils import counters, dynamo_timed
28
+
29
+ from torch._inductor.codegen.multi_kernel import MultiKernelState
30
+ from torch.fx.experimental.symbolic_shapes import SymTypes
31
+ from torch.fx.node import _get_qualified_name
32
+ from torch.utils._sympy.singleton_int import SingletonInt
33
+
34
+ from .. import codecache, config, ir
35
+ from ..ir import ReinterpretView
36
+ from ..utils import (
37
+ cache_on_self,
38
+ get_benchmark_name,
39
+ LineContext,
40
+ sympy_product,
41
+ sympy_str,
42
+ )
43
+ from ..virtualized import V
44
+ from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
45
+ from .triton_utils import config_of, signature_to_meta
46
+
47
+ if TYPE_CHECKING:
48
+ import triton
49
+
50
+ from ..graph import GraphLowering
51
+
52
+
53
+ pexpr = PythonPrinter().doprint
54
+
55
+
56
+ ReuseKey = Tuple[torch.device, torch.dtype, str]
57
+
58
+
59
def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
    """Key under which a freed buffer may be reused by a later allocation.

    Buffers can only share storage when device, dtype and (symbolic) storage
    size all match.
    """
    return (
        node.get_device(),
        node.get_dtype(),
        # NB: this is symbolic so that we don't try to reuse a buffer
        # for s0 for s1, just because they happen to share the same
        # size hint
        sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())),
    )
68
+
69
+
70
def convert_arg_type(arg: torch.Argument) -> str:
    """Translate a torch schema argument's Python type into its C++ spelling."""
    from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP

    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(arg.real_type)  # type: ignore[attr-defined]

    if python_type == "Tensor":
        # Conversions rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
        is_mutable = arg.alias_info is not None and arg.alias_info.is_write
        return f"at::{python_type}&" if is_mutable else f"at::{python_type} const&"

    if python_type in PYTHON_TO_CPP:
        return PYTHON_TO_CPP[python_type]

    # Convert args of container types e.g. Optional[*]
    for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items():
        found = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
        if len(found) != 1:
            continue
        contained_type = found[0]
        assert (
            contained_type in PYTHON_TO_CPP
        ), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
        return f"{cpp_container}<{PYTHON_TO_CPP[contained_type]}>"

    raise AssertionError(f"unsupport python_type: {python_type}")
99
+
100
+
101
def convert_return_type(ret: torch.Argument) -> str:
    """Translate a torch schema return's Python type into its C++ spelling."""
    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(ret.real_type)  # type: ignore[attr-defined]
    known_returns = {
        "Tensor": "at::Tensor",
        "List[Tensor]": "std::vector<at::Tensor>",
    }

    assert python_type in known_returns, f"NYI return type: {python_type}"
    cpp_type = known_returns[python_type]
    # An output aliasing an input is returned by reference only when it's a
    # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output
    # aliases the input tensor, but the op returns a vector by value.
    if python_type == "Tensor" and ret.alias_info is not None:
        cpp_type = cpp_type + "&"
    return cpp_type
117
+
118
+
119
def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
    """Build the C++ schema string ``ReturnType(ArgType name, ...)`` for *kernel*.

    Multiple returns are packed into a ``std::tuple``, matching the unboxed
    C++ calling convention.

    Raises:
        AssertionError: if the op schema declares no return values.
    """
    args = kernel._schema.arguments
    returns = kernel._schema.returns

    num_returns = len(returns)
    assert num_returns > 0, "must have at least one return value"

    # Plain if/else guarantees cpp_return_value is defined on every path
    # (the old elif chain needed a possibly-undefined type-ignore).
    if num_returns == 1:
        cpp_return_value = convert_return_type(returns[0])
    else:
        tuple_returns = ", ".join([convert_return_type(r) for r in returns])
        cpp_return_value = f"std::tuple<{tuple_returns}>"

    cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args]
    return f"{cpp_return_value}({', '.join(cpp_arg_type)})"
134
+
135
+
136
+ # TODO: Move to a well known place
137
+ TritonMetaParams = Dict[str, int]
138
+ TritonGrid = Union[
139
+ Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]]
140
+ ]
141
+
142
+
143
def user_defined_kernel_grid_fn_code(
    name: str,
    configs: List["triton.Config"],
    grids: List[TritonGrid],
    wrapper: Optional["WrapperCodeGen"] = None,
) -> Tuple[str, str]:
    """Generate the source of a grid function for a user-defined triton kernel.

    Returns ``(fn_name, code)``. With a single grid the generated function
    returns it unconditionally; with one grid per autotune config it selects
    the grid whose config kwargs match ``meta`` at launch time.
    """
    output = IndentedBuffer()

    def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr:
        return item if isinstance(item, sympy.Expr) else sympy.Integer(item)

    def determine_grid(grid: TritonGrid):
        if wrapper is None or callable(grid):
            # return as-is when used in eager mode or when grid is callable
            return grid
        # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen
        sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid)
        return wrapper.codegen_shape_tuple(sympy_grid)

    fn_name = f"grid_wrapper_for_{name}"
    output.writeline(f"def {fn_name}(meta):")
    with output.indent():
        if len(grids) == 1:
            grid = determine_grid(grids[0])
            output.writeline(f"return {grid}")
        else:
            assert len(grids) > 1
            assert len(grids) == len(configs)
            seen = set()
            for grid, c in zip(grids, configs):
                guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()]
                guards = " and ".join(guards)
                grid = determine_grid(grid)
                statement = f"if {guards}: return {grid}"
                # identical config kwargs would emit a duplicate branch; skip
                if statement in seen:
                    continue
                seen.add(statement)
                output.writeline(statement)

    return fn_name, output.getvalue()
183
+
184
+
185
@dataclasses.dataclass
class SymbolicCallArg:
    """A kernel-call argument standing in for a symbolic (dynamic-shape) value."""

    # name of the generated variable holding the runtime value
    inner: str
    # the original symbolic expression represented by inner
    inner_expr: sympy.Expr

    def __str__(self):
        return str(self.inner)
193
+
194
+
195
+ # Default thread stack sizes vary by platform:
196
+ # - Linux: 8 MB
197
+ # - macOS: 512 KB
198
+ # - Windows: 1 MB
199
+ # Just pick something comfortably smaller than the smallest for now.
200
+ MAX_STACK_ALLOCATION_SIZE = 1024 * 100
201
+
202
+
203
class MemoryPlanningState:
    """Tracks freed-but-not-yet-reused buffers during the memory-planning pass.

    Freed buffers are pooled under a reuse key so that a later allocation with
    a compatible key can take one over instead of allocating fresh storage.
    """

    def __init__(self):
        super().__init__()
        # LIFO pools of freed buffers, grouped by their reuse key
        self.reuse_pool: Dict[
            ReuseKey, List[FreeIfNotReusedLine]
        ] = collections.defaultdict(list)
        # running total of allocated buffer size
        self.total_allocated_buffer_size: int = 0

    def __contains__(self, key: ReuseKey) -> bool:
        """True when at least one freed buffer is available under *key*."""
        candidates = self.reuse_pool.get(key, None)
        return bool(candidates)

    def pop(self, key: ReuseKey) -> "FreeIfNotReusedLine":
        """Take the most recently freed, not-yet-reused buffer for *key*."""
        entry = self.reuse_pool[key].pop()
        assert not entry.is_reused
        return entry

    def push(self, key: ReuseKey, item: "FreeIfNotReusedLine") -> None:
        """Make *item* (which must not already be reused) available under *key*."""
        assert not item.is_reused
        self.reuse_pool[key].append(item)
222
+
223
+
224
class WrapperLine:
    """Base class for the wrapper code generator's line objects (see subclasses below)."""

    pass
226
+
227
+
228
@dataclasses.dataclass
class EnterSubgraphLine(WrapperLine):
    """Marks entry into a nested subgraph during wrapper codegen."""

    wrapper: "WrapperCodeGen"
    graph: "GraphLowering"

    def codegen(self, code: IndentedBuffer) -> None:
        # Track which (sub)graph we are generating code for, then indent
        # the output for the subgraph body.
        self.wrapper.push_codegened_graph(self.graph)
        code.do_indent()
236
+
237
+
238
@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
    """Marks exit from a nested subgraph; undoes EnterSubgraphLine's effects."""

    wrapper: "WrapperCodeGen"

    def codegen(self, code: IndentedBuffer) -> None:
        self.wrapper.pop_codegened_graph()
        code.do_unindent()
245
+
246
+
247
@dataclasses.dataclass
class EnterDeviceContextManagerLine(WrapperLine):
    """Opens a device-guard scope (C++ guard object or Python ``with`` block).

    ``last_seen_device_guard_index`` is the device index of the previous guard,
    or None when this is the first guard emitted.
    """

    device_idx: int
    last_seen_device_guard_index: Optional[int]

    def codegen(self, code: IndentedBuffer) -> None:
        if V.graph.cpp_wrapper:
            code.writeline("\n")
            if V.graph.aot_mode:
                # In AOT mode, we have a stream provided as a param. A stream is
                # associated with a device, so we never expect the device to change.
                # CUDAStreamGuard sets the stream and the device.
                if self.last_seen_device_guard_index is None:
                    if config.abi_compatible:
                        code.writeline(
                            "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);"
                        )
                    else:
                        code.writeline(
                            "at::cuda::CUDAStreamGuard stream_guard("
                            + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
                        )
                else:
                    assert (
                        self.last_seen_device_guard_index == self.device_idx
                    ), "AOTInductor only supports running on one CUDA device"
            else:
                # Non-AOT cpp wrapper: first guard constructs the guard object,
                # later guards just switch its device index.
                if self.last_seen_device_guard_index is None:
                    code.writeline(
                        f"AOTICudaGuard device_guard({self.device_idx});"
                        if config.abi_compatible
                        else f"at::cuda::CUDAGuard device_guard({self.device_idx});"
                    )
                else:
                    code.writeline(f"device_guard.set_index({self.device_idx});")
        else:
            # Note _DeviceGuard has less overhead than device, but only accepts
            # integers
            code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:")
            code.do_indent()
            code.writeline(V.graph.device_ops.set_device(self.device_idx))
288
+
289
+
290
+ class ExitDeviceContextManagerLine(WrapperLine):
291
+ def codegen(self, code: IndentedBuffer) -> None:
292
+ if not V.graph.cpp_wrapper:
293
+ code.do_unindent()
294
+
295
+
296
+ @dataclasses.dataclass
297
+ class MemoryPlanningLine(WrapperLine):
298
+ wrapper: "WrapperCodeGen"
299
+
300
+ def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
301
+ """First pass to find reuse"""
302
+ return self
303
+
304
+ def codegen(self, code: IndentedBuffer) -> None:
305
+ """Second pass to output code"""
306
+ pass
307
+
308
+ def __str__(self) -> str:
309
+ """
310
+ Emits a string representation that fits on one line.
311
+ """
312
+ args: List[str] = []
313
+ for field in dataclasses.fields(self):
314
+ if field.name == "wrapper":
315
+ continue
316
+ val = getattr(self, field.name)
317
+ args.append(
318
+ f"{field.name}={val.get_name() if field.type is ir.Buffer else val}"
319
+ )
320
+ return f"{type(self).__name__}({', '.join(args)})"
321
+
322
+
323
+ @dataclasses.dataclass
324
+ class AllocateLine(MemoryPlanningLine):
325
+ node: ir.Buffer
326
+
327
+ def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
328
+ if self.node.get_name() in V.graph.removed_buffers:
329
+ return NullLine(self.wrapper)
330
+
331
+ # try to reuse a recently freed buffer
332
+ key = buffer_reuse_key(self.node)
333
+ if config.allow_buffer_reuse and key in state:
334
+ free_line = state.pop(key)
335
+ free_line.is_reused = True
336
+ return ReuseLine(self.wrapper, free_line.node, self.node)
337
+
338
+ if self.node.get_device().type == "cpu":
339
+ static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
340
+ if static_shape is not None:
341
+ state.total_allocated_buffer_size += int(
342
+ functools.reduce(operator.mul, static_shape, 1)
343
+ )
344
+
345
+ return self
346
+
347
+ def codegen(self, code: IndentedBuffer) -> None:
348
+ assert self.node.get_name() not in V.graph.removed_buffers
349
+ line = self.wrapper.make_buffer_allocation(self.node)
350
+ code.writeline(line)
351
+
352
+
353
+ @dataclasses.dataclass
354
+ class FreeIfNotReusedLine(MemoryPlanningLine):
355
+ node: ir.Buffer
356
+ is_reused: bool = False
357
+
358
+ def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
359
+ if isinstance(self.node.layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
360
+ return self
361
+ assert not self.is_reused
362
+ if self.node.get_name() in V.graph.removed_buffers:
363
+ return NullLine(self.wrapper)
364
+ if config.allow_buffer_reuse:
365
+ state.push(buffer_reuse_key(self.node), self)
366
+ return self
367
+
368
+ def codegen(self, code: IndentedBuffer) -> None:
369
+ assert self.node.get_name() not in V.graph.removed_buffers
370
+ if not self.is_reused:
371
+ code.writeline(self.wrapper.make_buffer_free(self.node))
372
+
373
+
374
+ @dataclasses.dataclass
375
+ class ReuseLine(MemoryPlanningLine):
376
+ node: ir.Buffer
377
+ reused_as: ir.Buffer
378
+ delete_old: bool = True
379
+
380
+ def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
381
+ if self.node.get_name() in V.graph.removed_buffers:
382
+ assert self.reused_as.get_name() in V.graph.removed_buffers
383
+ return NullLine(self.wrapper)
384
+ assert self.reused_as.get_name() not in V.graph.removed_buffers
385
+ return self
386
+
387
+ def codegen(self, code: IndentedBuffer) -> None:
388
+ assert self.node.get_name() not in V.graph.removed_buffers
389
+ assert self.reused_as.get_name() not in V.graph.removed_buffers
390
+ code.writeline(
391
+ self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
392
+ )
393
+
394
+
395
+ class NullLine(MemoryPlanningLine):
396
+ pass
397
+
398
+
399
+ BufferName = str
400
+
401
+
402
+ class WrapperCodeGen(CodeGen):
403
+ """
404
+ Generate outer wrapper in Python that calls the kernels.
405
+ """
406
+
407
+ def __init__(self):
408
+ super().__init__()
409
+ self._names_iter: Iterator[int] = count()
410
+ self.header = IndentedBuffer()
411
+ self.prefix = IndentedBuffer()
412
+ self.suffix = IndentedBuffer()
413
+ self.wrapper_call = IndentedBuffer()
414
+ # If the generated source code is exactly the same, reuse the
415
+ # pre-existing kernel for it
416
+ self.src_to_kernel: Dict[str, str] = {}
417
+ self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set()
418
+ self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
419
+ self.declare = ""
420
+ self.declare_maybe_reference = ""
421
+ self.ending = ""
422
+ self.open_bracket = "["
423
+ self.closed_bracket = "]"
424
+ self.comment = "#"
425
+ self.namespace = ""
426
+ self.none_str = "None"
427
+ self.size = "size()"
428
+ self.stride = "stride()"
429
+ self.last_seen_device_guard_index: Optional[int] = None
430
+ self.supports_intermediate_hooks = True
431
+ self.expr_printer = pexpr
432
+ self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {}
433
+ self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol
434
+ self.allow_stack_allocation: Optional[bool] = None
435
+ self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {}
436
+ self.computed_sizes: Set[sympy.Symbol] = set()
437
+
438
+ # this is used for tracking which GraphLowering instance---parent graph
439
+ # or (nested) subgraph---is currently codegened; the primary use case is
440
+ # including the graph instance into a cache key to avoid cross-graph
441
+ # caching during lowering of nested subgraphs
442
+ self.codegened_graph_stack = [V.graph]
443
+
444
+ self.write_header()
445
+ self.write_prefix()
446
+
447
+ if not V.graph.aot_mode:
448
+ for name, hashed in V.graph.constant_reprs.items():
449
+ # include a hash so our code cache puts different constants into different files
450
+ self.write_constant(name, hashed)
451
+
452
+ self.allocated: Set[BufferName] = set()
453
+ self.freed: Set[BufferName] = set()
454
+
455
+ # maps from reusing buffer to reused buffer
456
+ self.reuses: Dict[BufferName, BufferName] = dict()
457
+
458
+ self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment]
459
+ self.write_get_raw_stream
460
+ )
461
+
462
+ @functools.lru_cache(None)
463
+ def add_import_once(line: str) -> None:
464
+ self.header.writeline(line)
465
+
466
+ self.add_import_once = add_import_once
467
+ self._metas: Dict[str, str] = {}
468
+ self.multi_kernel_state = MultiKernelState()
469
+
470
+ def write_constant(self, name: str, hashed: str) -> None:
471
+ self.header.writeline(f"{name} = None # {hashed}")
472
+
473
+ def write_header(self) -> None:
474
+ self.header.splice(
475
+ f"""
476
+ from ctypes import c_void_p, c_long
477
+ import torch
478
+ import math
479
+ import random
480
+ import os
481
+ import tempfile
482
+ from math import inf, nan
483
+ from torch._inductor.hooks import run_intermediate_hooks
484
+ from torch._inductor.utils import maybe_profile
485
+ from torch._inductor.codegen.memory_planning import _align as align
486
+
487
+ from torch import device, empty_strided
488
+ from {codecache.__name__} import AsyncCompile
489
+ from torch._inductor.select_algorithm import extern_kernels
490
+ from torch._inductor.codegen.multi_kernel import MultiKernelCall
491
+
492
+ aten = torch.ops.aten
493
+ inductor_ops = torch.ops.inductor
494
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
495
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
496
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
497
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
498
+ reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
499
+ async_compile = AsyncCompile()
500
+
501
+ """
502
+ )
503
+
504
+ @cache_on_self
505
+ def write_triton_header_once(self) -> None:
506
+ self.header.splice(
507
+ """
508
+ import triton
509
+ import triton.language as tl
510
+ from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
511
+ {}
512
+ """.format(
513
+ V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
514
+ )
515
+ )
516
+
517
+ def add_meta_once(self, meta: TritonMetaParams) -> str:
518
+ meta = repr(meta)
519
+ if meta not in self._metas:
520
+ var = f"meta{len(self._metas)}"
521
+ self._metas[meta] = var
522
+ self.header.writeline(f"{var} = {meta}")
523
+ return self._metas[meta]
524
+
525
+ @cache_on_self
526
+ def get_output_refs(self) -> List[str]:
527
+ return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs]
528
+
529
+ def mark_output_type(self) -> None:
530
+ return
531
+
532
+ def codegen_input_size_asserts(self) -> None:
533
+ for name, buf in V.graph.graph_inputs.items():
534
+ if isinstance(buf, sympy.Expr):
535
+ continue
536
+
537
+ # comparing strides for 0 size tensor is tricky. Ignore them for now.
538
+ if sympy_product(buf.get_size()) == 0:
539
+ continue
540
+ size = self.codegen_shape_tuple(buf.get_size())
541
+ stride = self.codegen_shape_tuple(buf.get_stride())
542
+ self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})")
543
+
544
+ def codegen_input_nan_asserts(self) -> None:
545
+ self.prefix.writeline("# make sure graph inputs are not nan/inf")
546
+ for name, buf in V.graph.graph_inputs.items():
547
+ if isinstance(buf, sympy.Expr):
548
+ continue
549
+
550
+ line = f"assert not {name}.isnan().any().item()"
551
+ self.prefix.writeline(line)
552
+ line = f"assert not {name}.isinf().any().item()"
553
+ self.prefix.writeline(line)
554
+
555
+ def write_prefix(self) -> None:
556
+ self.prefix.splice(
557
+ """
558
+
559
+ async_compile.wait(globals())
560
+ del async_compile
561
+
562
+ def call(args):
563
+ """
564
+ )
565
+ with self.prefix.indent():
566
+ if config.triton.debug_sync_graph:
567
+ self.prefix.writeline(V.graph.device_ops.synchronize())
568
+ if V.graph.graph_inputs:
569
+ lhs = ", ".join(V.graph.graph_input_names)
570
+ if len(V.graph.graph_input_names) == 1:
571
+ lhs += ","
572
+ self.prefix.writeline(f"{lhs} = args")
573
+ self.prefix.writeline("args.clear()")
574
+
575
+ self.codegen_inputs(self.prefix, V.graph.graph_inputs)
576
+ if config.size_asserts:
577
+ self.codegen_input_size_asserts()
578
+ if config.nan_asserts:
579
+ self.codegen_input_nan_asserts()
580
+
581
+ # this function (and below) takes a graph as input so
582
+ # that stream caching happens per graph instance. this
583
+ # is important for nested subgraph codegening.
584
+ def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
585
+ self.write_triton_header_once()
586
+ name = f"stream{device_idx}"
587
+ self.writeline(f"{name} = get_raw_stream({device_idx})")
588
+ return name
589
+
590
+ def get_codegened_graph(self):
591
+ return self.codegened_graph_stack[-1]
592
+
593
+ def push_codegened_graph(self, graph):
594
+ self.codegened_graph_stack.append(graph)
595
+
596
+ def pop_codegened_graph(self):
597
+ return self.codegened_graph_stack.pop()
598
+
599
+ def next_kernel_suffix(self) -> str:
600
+ return f"{next(self._names_iter)}"
601
+
602
+ def codegen_device_guard_enter(self, device_idx: int) -> None:
603
+ self.writeline(
604
+ EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
605
+ )
606
+ self.last_seen_device_guard_index = device_idx
607
+
608
+ def codegen_device_guard_exit(self) -> None:
609
+ self.writeline(ExitDeviceContextManagerLine())
610
+
611
+ def generate_return(self, output_refs: List[str]) -> None:
612
+ if output_refs:
613
+ self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
614
+ else:
615
+ self.wrapper_call.writeline("return ()")
616
+
617
+ def generate_before_suffix(self, result: IndentedBuffer) -> None:
618
+ return
619
+
620
+ def generate_end(self, result: IndentedBuffer) -> None:
621
+ return
622
+
623
+ def generate_fallback_kernel(self, fallback_kernel, args):
624
+ self.generate_extern_kernel_alloc(fallback_kernel, args)
625
+
626
+ def generate_extern_kernel_alloc(self, extern_kernel, args):
627
+ output_name = extern_kernel.get_name()
628
+ origin_node = extern_kernel.get_origin_node()
629
+ kernel_name = extern_kernel.get_kernel_name()
630
+ ending = self.ending
631
+ if config.memory_planning and "view_as_complex" in kernel_name:
632
+ # view operation fallbacks cause issues since inductor
633
+ # doesn't know the memory is still needed and might reuse it.
634
+ ending = f".clone(){ending}"
635
+ self.writeline(
636
+ f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}"
637
+ )
638
+ if (
639
+ self.supports_intermediate_hooks
640
+ and config.generate_intermediate_hooks
641
+ and origin_node is not None
642
+ ):
643
+ counters["inductor"]["intermediate_hooks"] += 1
644
+ self.writeline(
645
+ f"run_intermediate_hooks({origin_node.name!r}, {output_name})"
646
+ )
647
+
648
+ def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
649
+ if output_view:
650
+ args.append(f"out={output_view.codegen_reference()}")
651
+ else:
652
+ args.append(f"out={codegen_reference}")
653
+ self.writeline(f"{kernel}({', '.join(args)})")
654
+
655
+ def generate_user_defined_triton_kernel(
656
+ self, kernel_name, grid, configs, args, triton_meta
657
+ ):
658
+ grid, code = user_defined_kernel_grid_fn_code(
659
+ kernel_name, configs, grid, wrapper=self
660
+ )
661
+ # Must happen after free symbols are already codegened
662
+ # Emit the grid wrapper function right before the call
663
+ for line in code.split("\n"):
664
+ self.writeline(line)
665
+
666
+ stream_name = self.write_get_raw_stream(
667
+ V.graph.scheduler.current_device.index, V.graph
668
+ )
669
+ self.writeline(
670
+ f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})"
671
+ )
672
+
673
+ def generate_scatter_fallback(
674
+ self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
675
+ ):
676
+ line = f"{kernel}({','.join(map(str, inputs))}"
677
+ if kernel == "aten.scatter_":
678
+ if reduce:
679
+ line += f", reduce={repr(reduce)}"
680
+ else:
681
+ line += ", ".join([""] + kwargs)
682
+ line += f"){self.ending}"
683
+ self.writeline(line)
684
+
685
+ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
686
+ indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
687
+ args = [x, indices_str, values, accumulate]
688
+ self.writeline(self.wrap_kernel_call(kernel, args))
689
+
690
+ def generate_extern_kernel_alloc_and_find_schema_if_needed(
691
+ self,
692
+ name,
693
+ kernel,
694
+ codegen_args,
695
+ cpp_op_schema,
696
+ cpp_kernel_key,
697
+ cpp_kernel_overload_name="",
698
+ op_overload=None,
699
+ raw_args=None,
700
+ outputs=None,
701
+ ):
702
+ self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")
703
+
704
+ def generate_inf_and_nan_checker(self, node):
705
+ # TODO: Add check for python too.
706
+ pass
707
+
708
+ @dynamo_timed
709
+ def generate(self, is_inference):
710
+ if config.profile_bandwidth:
711
+ self.write_triton_header_once()
712
+ result = IndentedBuffer()
713
+ result.splice(self.header)
714
+
715
+ with contextlib.ExitStack() as stack:
716
+ stack.enter_context(self.wrapper_call.indent())
717
+ if config.profiler_mark_wrapper_call:
718
+ self.generate_profiler_mark_wrapper_call(stack)
719
+ if config.profile_bandwidth:
720
+ self.generate_start_graph()
721
+
722
+ # We disable planning during training because it presently increases peak memory consumption.
723
+ if is_inference and config.memory_planning:
724
+ self.memory_plan()
725
+ # TODO: integrate memory planning & stack allocation?
726
+ self.allow_stack_allocation = False
727
+ else:
728
+ self.memory_plan_reuse()
729
+
730
+ if config.triton.store_cubin:
731
+ self.generate_reset_kernel_saved_flags()
732
+
733
+ for line in self.lines:
734
+ if isinstance(line, WrapperLine):
735
+ line.codegen(self.wrapper_call)
736
+ else:
737
+ self.wrapper_call.writeline(line)
738
+
739
+ output_refs = self.get_output_refs()
740
+ self.mark_output_type()
741
+ if config.triton.debug_sync_graph:
742
+ self.wrapper_call.writeline(V.graph.device_ops.synchronize())
743
+
744
+ if config.profile_bandwidth:
745
+ self.generate_end_graph()
746
+
747
+ if config.triton.store_cubin:
748
+ self.generate_save_uncompiled_kernels()
749
+
750
+ self.generate_return(output_refs)
751
+
752
+ self.finalize_prefix()
753
+ result.splice(self.prefix)
754
+
755
+ with result.indent():
756
+ result.splice(self.wrapper_call)
757
+
758
+ self.generate_before_suffix(result)
759
+ result.splice(self.suffix)
760
+
761
+ self.generate_end(result)
762
+
763
+ self.add_benchmark_harness(result)
764
+
765
+ return result.getvaluewithlinemap()
766
+
767
+ def memory_plan(self):
768
+ from .memory_planning import MemoryPlanner
769
+
770
+ self.lines = MemoryPlanner(self).plan(self.lines)
771
+
772
+ def memory_plan_reuse(self):
773
+ out_names = V.graph.get_output_names()
774
+
775
+ while (
776
+ self.lines
777
+ and isinstance(self.lines[-1], MemoryPlanningLine)
778
+ # TODO: this seems legit, NullLine has no node
779
+ and self.lines[-1].node.name not in out_names # type: ignore[attr-defined]
780
+ ):
781
+ # these lines will be pointless
782
+ self.lines.pop()
783
+
784
+ # codegen allocations in two passes
785
+ planning_states = [MemoryPlanningState()]
786
+ past_planning_states = []
787
+ for i in range(len(self.lines)):
788
+ line = self.lines[i]
789
+ if isinstance(line, MemoryPlanningLine):
790
+ self.lines[i] = line.plan(planning_states[-1])
791
+ elif isinstance(line, EnterSubgraphLine):
792
+ planning_states.append(MemoryPlanningState())
793
+ elif isinstance(line, ExitSubgraphLine):
794
+ past_planning_states.append(planning_states.pop())
795
+ past_planning_states.append(planning_states.pop())
796
+ assert len(planning_states) == 0
797
+
798
+ # conservatively use the sum of all allocated buffer sizes
799
+ # in potentially nested scopes as the total allocated size
800
+ total_allocated_buffer_size = sum(
801
+ s.total_allocated_buffer_size for s in past_planning_states
802
+ )
803
+
804
+ self.allow_stack_allocation = (
805
+ self.allow_stack_allocation is not False
806
+ and config.allow_stack_allocation
807
+ and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE
808
+ )
809
+
810
+ def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
811
+ code.writeline(f"{self.declare}{name}_size = {name}.{self.size}{self.ending}")
812
+
813
+ def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
814
+ code.writeline(
815
+ f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}"
816
+ )
817
+
818
+ def codegen_inputs(
819
+ self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox]
820
+ ):
821
+ """Assign all symbolic shapes to locals"""
822
+
823
+ @functools.lru_cache(None)
824
+ def sizeof(name):
825
+ self.codegen_input_size_var_decl(code, name)
826
+ return f"{name}_size"
827
+
828
+ @functools.lru_cache(None)
829
+ def strideof(name):
830
+ self.codegen_input_stride_var_decl(code, name)
831
+ return f"{name}_stride"
832
+
833
+ # Assign all symbolic shapes needed to local variables
834
+ needed = V.graph.sizevars.free_symbols()
835
+
836
+ def is_expr(x):
837
+ return isinstance(x[1], sympy.Expr)
838
+
839
+ graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
840
+ graph_inputs_tensors = list(
841
+ filter(lambda x: not is_expr(x), graph_inputs.items())
842
+ )
843
+
844
+ for name, shape in graph_inputs_expr:
845
+ shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
846
+ if shape in needed:
847
+ needed.remove(shape) # type: ignore[arg-type]
848
+ code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
849
+
850
+ for name, value in graph_inputs_tensors:
851
+ shapes = value.get_size()
852
+ for dim, shape in enumerate(shapes):
853
+ shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
854
+ if shape in needed:
855
+ needed.remove(shape) # type: ignore[arg-type]
856
+ code.writeline(
857
+ f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
858
+ )
859
+
860
+ for name, value in graph_inputs_tensors:
861
+ shapes = value.get_stride()
862
+ for dim, shape in enumerate(shapes):
863
+ shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type]
864
+ if shape in needed:
865
+ needed.remove(shape) # type: ignore[arg-type]
866
+ code.writeline(
867
+ f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
868
+ )
869
+
870
+ def ensure_size_computed(self, sym: sympy.Symbol):
871
+ if isinstance(sym, sympy.Symbol) and sym.name.startswith("ps"):
872
+ if sym in self.computed_sizes:
873
+ return
874
+ self.computed_sizes.add(sym)
875
+ expr = V.graph.sizevars.inv_precomputed_replacements[sym]
876
+ self.writeline(
877
+ f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
878
+ )
879
+
880
+ def finalize_prefix(self):
881
+ pass
882
+
883
+ def codegen_python_sizevar(self, x: Expr) -> str:
884
+ return pexpr(V.graph.sizevars.simplify(x))
885
+
886
+ def codegen_sizevar(self, x: Expr) -> str:
887
+ return self.codegen_python_sizevar(x)
888
+
889
+ def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
890
+ return f"{basename}[{index}]"
891
+
892
+ def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
893
+ parts = list(map(self.codegen_python_sizevar, shape))
894
+ if len(parts) == 0:
895
+ return "()"
896
+ if len(parts) == 1:
897
+ return f"({parts[0]}, )"
898
+ return f"({', '.join(parts)})"
899
+
900
+ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
901
+ return self.codegen_python_shape_tuple(shape)
902
+
903
+ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
904
+ return "alloc_from_pool({})".format(
905
+ ", ".join(
906
+ [
907
+ name,
908
+ pexpr(offset), # bytes not numel
909
+ str(dtype),
910
+ self.codegen_shape_tuple(shape),
911
+ self.codegen_shape_tuple(stride),
912
+ ]
913
+ )
914
+ )
915
+
916
+ def codegen_reinterpret_view(self, data, size, stride, offset, writer) -> str:
917
+ size = self.codegen_shape_tuple(size)
918
+ stride = self.codegen_shape_tuple(stride)
919
+ offset = self.codegen_sizevar(offset)
920
+ return f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})"
921
+
922
+ def codegen_device_copy(self, src, dst):
923
+ self.writeline(f"{dst}.copy_({src})")
924
+
925
+ def codegen_multi_output(self, name, value):
926
+ self.writeline(f"{self.declare}{name} = {value}{self.ending}")
927
+
928
+ def codegen_dynamic_scalar(self, node):
929
+ (data,) = (t.codegen_reference() for t in node.inputs)
930
+ if node.is_bool:
931
+ self.writeline(f"{node.sym} = 1 if {data}.item() else 0")
932
+ else:
933
+ self.writeline(f"{node.sym} = {data}.item()")
934
+ # No one should ever use this buffer, but for uniformity
935
+ # define the variable and assign it None
936
+ self.writeline(f"{node.get_name()} = None")
937
+
938
+ def benchmark_compiled_module(self, output):
939
+ def add_fake_input(name, shape, stride, device, dtype):
940
+ output.writeline(
941
+ f"{name} = rand_strided("
942
+ f"{self.codegen_python_shape_tuple(shape)}, "
943
+ f"{self.codegen_python_shape_tuple(stride)}, "
944
+ f"device='{device}', dtype={dtype})"
945
+ )
946
+
947
+ def add_expr_input(name, val):
948
+ output.writeline(f"{name} = {val}")
949
+
950
+ output.writelines(
951
+ ["", "", "def benchmark_compiled_module(times=10, repeat=10):"]
952
+ )
953
+ with output.indent():
954
+ output.splice(
955
+ """
956
+ from torch._dynamo.testing import rand_strided
957
+ from torch._inductor.utils import print_performance
958
+ """,
959
+ strip=True,
960
+ )
961
+
962
+ for name, value in V.graph.constants.items():
963
+ # all the constants are global variables, that's why we need
964
+ # these 'global var_name' lines
965
+ output.writeline(f"global {name}")
966
+ add_fake_input(
967
+ name, value.size(), value.stride(), value.device, value.dtype
968
+ )
969
+
970
+ for name, value in V.graph.graph_inputs.items():
971
+ if isinstance(value, sympy.Symbol) and isinstance(
972
+ V.graph.sizevars.var_to_val.get(value, None), SingletonInt
973
+ ):
974
+ # Inductor should only work with dense -> dense graph, and
975
+ # SingletonInts belong to metadata that should only live on
976
+ # the subclass.
977
+ continue
978
+ if isinstance(value, sympy.Expr): # Don't need to add symbolic
979
+ add_expr_input(name, V.graph.sizevars.size_hint(value))
980
+ else:
981
+ shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
982
+ stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
983
+ add_fake_input(
984
+ name, shape, stride, value.get_device(), value.get_dtype()
985
+ )
986
+
987
+ call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])"
988
+ output.writeline(f"fn = lambda: {call_str}")
989
+ output.writeline("return print_performance(fn, times=times, repeat=repeat)")
990
+
991
+ def add_benchmark_harness(self, output):
992
+ """
993
+ Append a benchmark harness to generated code for debugging
994
+ """
995
+ if not config.benchmark_harness:
996
+ return
997
+
998
+ self.benchmark_compiled_module(output)
999
+
1000
+ output.writelines(["", "", 'if __name__ == "__main__":'])
1001
+ with output.indent():
1002
+ output.writelines(
1003
+ [
1004
+ "from torch._inductor.wrapper_benchmark import compiled_module_main",
1005
+ f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)",
1006
+ ]
1007
+ )
1008
+
1009
+ def define_kernel(
1010
+ self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
1011
+ ):
1012
+ metadata_comment = f"{metadata}\n" if metadata else ""
1013
+ self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
1014
+
1015
+ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
1016
+ original_name = kernel.__name__
1017
+
1018
+ from .common import KernelArgType, SizeArg, TensorArg
1019
+
1020
+ signature: List[KernelArgType] = []
1021
+ constants: Dict[int, Any] = {}
1022
+ non_constant_indices = []
1023
+ equal_to_1_arg_idx: List[int] = []
1024
+ for idx, key in enumerate(kernel.arg_names):
1025
+ if key not in kwargs:
1026
+ continue
1027
+ arg = kwargs[key]
1028
+ if idx in kernel.constexprs:
1029
+ constants[idx] = arg
1030
+ else:
1031
+ non_constant_indices.append(idx)
1032
+ if isinstance(arg, ir.Buffer):
1033
+ signature.append(
1034
+ TensorArg(
1035
+ name=key,
1036
+ buffer=arg.get_name(),
1037
+ dtype=arg.get_dtype(),
1038
+ )
1039
+ )
1040
+ elif isinstance(arg, ir.ReinterpretView):
1041
+ # for ReinterpretView we use the underlying
1042
+ # buffer name and note the (possibly non-zero)
1043
+ # offset relative to the underlying buffer
1044
+ signature.append(
1045
+ TensorArg(
1046
+ name=key,
1047
+ buffer=arg.data.get_name(),
1048
+ dtype=arg.get_dtype(),
1049
+ offset=arg.layout.offset,
1050
+ )
1051
+ )
1052
+ else:
1053
+ signature.append(SizeArg(key, arg))
1054
+ if arg is not None and V.graph.sizevars.statically_known_equals(arg, 1): # type: ignore[arg-type]
1055
+ equal_to_1_arg_idx.append(idx)
1056
+ index_dtype = "tl.int32"
1057
+ triton_meta = {
1058
+ "signature": signature_to_meta(
1059
+ signature,
1060
+ size_dtype=index_dtype,
1061
+ indices=non_constant_indices,
1062
+ ),
1063
+ "device": V.graph.scheduler.current_device.index,
1064
+ "device_type": V.graph.scheduler.current_device.type,
1065
+ # Triton compiler includes equal_to_1 args into constants even
1066
+ # when they are not constexpr. otherwise there may be a segfault
1067
+ # during launching the Inductor-compiled Triton kernel.
1068
+ # TODO(aakhundov): add None args to constants, too. currently, this
1069
+ # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
1070
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
1071
+ # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
1072
+ "constants": {
1073
+ **constants,
1074
+ **{idx: 1 for idx in equal_to_1_arg_idx},
1075
+ },
1076
+ "configs": [
1077
+ config_of(
1078
+ signature,
1079
+ indices=non_constant_indices,
1080
+ )
1081
+ ],
1082
+ }
1083
+
1084
+ # Distinguish between different functions using function id
1085
+ cache_key: List[Any] = [id(kernel.fn)]
1086
+ if len(configs) > 0:
1087
+ for arg in kwargs.values():
1088
+ # We need to key on non tensor arg only in autotune mode
1089
+ if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
1090
+ cache_key.append(arg)
1091
+ cache_key.append(str(triton_meta))
1092
+ cache_key = tuple(cache_key)
1093
+
1094
+ if cache_key in self.user_defined_kernel_cache:
1095
+ return self.user_defined_kernel_cache[cache_key]
1096
+
1097
+ name = f"{original_name}_{len(self.user_defined_kernel_cache)}"
1098
+ # Add to the cache for the next use
1099
+ self.user_defined_kernel_cache[cache_key] = (name, triton_meta)
1100
+
1101
+ compile_wrapper = IndentedBuffer()
1102
+ compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
1103
+
1104
+ from .triton import gen_common_triton_imports
1105
+
1106
+ compile_wrapper.splice(gen_common_triton_imports())
1107
+
1108
+ inductor_meta = {
1109
+ "kernel_name": name,
1110
+ "backend_hash": torch.utils._triton.triton_hash_with_backend(),
1111
+ }
1112
+
1113
+ configs = [
1114
+ {
1115
+ "kwargs": config.kwargs,
1116
+ "num_warps": config.num_warps,
1117
+ "num_stages": config.num_stages,
1118
+ }
1119
+ for config in configs
1120
+ ]
1121
+
1122
+ compile_wrapper.splice(
1123
+ f"""
1124
+ @triton_heuristics.user_autotune(
1125
+ configs={configs!r},
1126
+ inductor_meta={inductor_meta!r},
1127
+ triton_meta={triton_meta!r},
1128
+ filename=__file__,
1129
+ custom_kernel=True,
1130
+ )
1131
+ @triton.jit
1132
+ """
1133
+ )
1134
+ compile_wrapper.splice(kernel.src, strip=True)
1135
+
1136
+ # Also include any possible kernel being called indirectly
1137
+ from triton import JITFunction
1138
+
1139
+ symbols_included = {original_name}
1140
+
1141
+ def traverse(cur_kernel):
1142
+ for symbol_name in cur_kernel.fn.__code__.co_names:
1143
+ if symbol_name in symbols_included:
1144
+ continue
1145
+ if symbol_name in cur_kernel.fn.__globals__:
1146
+ symbol = cur_kernel.fn.__globals__[symbol_name]
1147
+ if isinstance(symbol, JITFunction):
1148
+ compile_wrapper.newline()
1149
+ compile_wrapper.writeline("@triton.jit")
1150
+ compile_wrapper.splice(symbol.src, strip=True)
1151
+ symbols_included.add(symbol_name)
1152
+ traverse(symbol)
1153
+ elif isinstance(symbol, (int, str, bool)):
1154
+ compile_wrapper.newline()
1155
+ compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
1156
+ symbols_included.add(symbol_name)
1157
+
1158
+ traverse(kernel)
1159
+
1160
+ compile_wrapper.writeline(
1161
+ f"''', device_str='{V.graph.scheduler.current_device.type}')"
1162
+ )
1163
+ _, lineno = inspect.getsourcelines(kernel.fn)
1164
+ srcfile = inspect.getsourcefile(kernel.fn)
1165
+ metadata = f"# Original path: {srcfile}:{lineno}"
1166
+ self.define_kernel(
1167
+ name,
1168
+ compile_wrapper.getvalue(),
1169
+ metadata,
1170
+ )
1171
+ return name, triton_meta
1172
+
1173
+ def generate_numel_expr(self, kernel_name: str, tree):
1174
+ expr = f"{kernel_name}_{tree.prefix}numel"
1175
+ if (expr, V.graph) not in self.kernel_numel_expr:
1176
+ # declare expr once in each graph (scope)
1177
+ self.kernel_numel_expr.add((expr, V.graph))
1178
+ self.writeline(
1179
+ f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}"
1180
+ )
1181
+ else:
1182
+ self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}")
1183
+ # We can get symbolic expressions here, like s0*64
1184
+ # It is fine to have them here, but we need to handle them correctly as their own type
1185
+ # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
1186
+ # scalars as well.
1187
+ # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
1188
+ # constant now, need type info. I agree, this needs type info, and while this is not true type info
1189
+ # it suffices as a type hint for the purposes of producing the correct code for this type.
1190
+ return SymbolicCallArg(expr, tree.numel)
1191
+
1192
+ def generate_workspace_allocation(self, nbytes, device, zero_fill):
1193
+ line = self.make_allocation(
1194
+ "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
1195
+ )
1196
+ self.writeline(line)
1197
+ if zero_fill:
1198
+ self.writeline(f"workspace.zero_(){self.ending}")
1199
+
1200
+ def wrap_kernel_call(self, name, call_args):
1201
+ return f"{name}({', '.join(call_args)}){self.ending}"
1202
+
1203
    def generate_profiler_mark_wrapper_call(self, stack):
        # Wrap the generated wrapper body in a torch.profiler record_function
        # region so the whole graph call shows up as a single profiler event.
        self.wrapper_call.writeline("from torch.profiler import record_function")
        self.wrapper_call.writeline(
            f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):"
        )
        # Keep the `with` block's indent open for subsequent codegen; the
        # caller's ExitStack closes it when the wrapper body is complete.
        stack.enter_context(self.wrapper_call.indent())
1209
+
1210
    def generate_start_graph(self):
        # Emit the hook marking the start of the wrapper body (presumably a
        # profiling/debug helper provided by the runtime prelude — confirm).
        self.wrapper_call.writeline("start_graph()")
1212
+
1213
    def generate_end_graph(self):
        # Emit the matching end-of-wrapper hook for generate_start_graph.
        self.wrapper_call.writeline("end_graph()")
1215
+
1216
    def generate_reset_kernel_saved_flags(self):
        # Emit runtime code that clears `cuda_kernel_saved` on every cached
        # autotuner in the generated module, so a later pass
        # (generate_save_uncompiled_kernels) can tell which kernels actually
        # saved their CUBIN during this run.
        self.wrapper_call.splice(
            """
            for kernel in globals().values():
                if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner):
                    kernel.cuda_kernel_saved = False
            """
        )
1224
+
1225
    def generate_save_uncompiled_kernels(self):
        """
        Precompile and save the CUBINs of the Triton kernels that haven't
        been precompiled and saved as a side effect of running the generated
        JIT model (Python wrapper). This can happen when the model contains
        control flow: only one pass through the control flow operators covers
        the kernels that are saved, the remaining kernels are not launched,
        hence not saved. The main purpose of this codegen is to compile and
        save the Triton kernels outside the active control flow path for
        subsequent AOTInductor code generation and compilation.
        """
        # Runtime scan over module globals: any CachingAutotuner whose CUBIN
        # was not saved during the run is precompiled (if it never launched)
        # and saved with placeholder grid/stream values.
        self.wrapper_call.splice(
            """
            for kernel in globals().values():
                if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner):
                    if not kernel.cuda_kernel_saved:
                        if len(kernel.launchers) == 0:
                            kernel.precompile()
                        kernel.save_cuda_kernel(
                            grid=(0, 0, 0),   # use dummy grid
                            stream="stream",  # use dummy stream
                            launcher=kernel.launchers[0],
                        )
            """
        )
1250
+
1251
    def generate_default_grid(self, name: str, grid_args: List[Any]):
        # Python wrapper uses the grid arguments unchanged; subclasses may
        # override to transform them.
        return grid_args
1253
+
1254
+ def generate_kernel_call(
1255
+ self,
1256
+ name,
1257
+ call_args,
1258
+ grid=None,
1259
+ device_index=None,
1260
+ cuda=True,
1261
+ triton=True,
1262
+ arg_types=None,
1263
+ grid_fn: str = "grid",
1264
+ triton_meta=None,
1265
+ ):
1266
+ """
1267
+ Generates kernel call code.
1268
+
1269
+ cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.
1270
+
1271
+ triton: Defines whether the GPU backend uses Triton for codegen.
1272
+ Otherwise it uses the CUDA language for codegen.
1273
+ Only valid when cuda == True.
1274
+ """
1275
+ if cuda:
1276
+ call_args_str = ", ".join(pexpr(item) for item in call_args)
1277
+ stream_name = self.write_get_raw_stream(
1278
+ V.graph.scheduler.current_device.index, V.graph
1279
+ )
1280
+ if triton:
1281
+ grid_str = ", ".join(pexpr(item) for item in grid)
1282
+ grid_str = f"{grid_fn}({grid_str})"
1283
+ self.writeline(
1284
+ f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
1285
+ )
1286
+ else:
1287
+ stream_ptr = f"c_void_p({stream_name})"
1288
+ self.writeline(f"{name}.{name}({call_args_str}, {stream_ptr})")
1289
+ else:
1290
+ self.writeline(self.wrap_kernel_call(name, call_args))
1291
+
1292
    def writeline(self, line):
        # Append one line (a str or a deferred-line object) to the body.
        self.lines.append(line)
1294
+
1295
    def enter_context(self, ctx):
        # Record a context manager to be entered at this point when the
        # collected lines are materialized.
        self.lines.append(LineContext(ctx))
1297
+
1298
    def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
        # Not meaningful for the Python wrapper; presumably implemented by
        # the C++ wrapper codegen subclasses — confirm.
        raise NotImplementedError()
1300
+
1301
    def val_to_arg_str(self, s):
        # Render a Python-source representation of an arbitrary call argument.
        if isinstance(s, SymTypes):
            # SymInt/SymFloat/etc.: round-trip through sympy's printer.
            return pexpr(sympy.expand(repr(s)))
        elif isinstance(s, sympy.Expr):
            return pexpr(s)
        elif isinstance(s, (tuple, list)):

            @dataclasses.dataclass
            class Shim:
                ref: Any

                def __repr__(self):
                    return self.ref

            # Shim makes repr() of the container splice in each element's
            # already-rendered source text instead of quoting it as a string.
            return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
        elif isinstance(s, torch._ops.OpOverload):
            return _get_qualified_name(s)
        elif isinstance(s, (ir.Buffer, ReinterpretView)):
            return s.codegen_reference()
        else:
            # Fallback: repr() is valid source for ints, floats, strs, ...
            return repr(s)
1322
+
1323
+ # The following methods are for memory management
1324
+ def make_buffer_allocation(self, buffer):
1325
+ device = buffer.get_device()
1326
+ dtype = buffer.get_dtype()
1327
+ shape = tuple(buffer.get_size())
1328
+ stride = tuple(buffer.get_stride())
1329
+ return self.make_allocation(buffer.get_name(), device, dtype, shape, stride)
1330
+
1331
+ def make_allocation(self, name, device, dtype, shape, stride):
1332
+ if device.type in ("cpu", "cuda"):
1333
+ # optimized path for faster allocations, saving ~2us versus the stuff below
1334
+ return (
1335
+ f"{name} = empty_strided_{device.type}("
1336
+ f"{self.codegen_shape_tuple(shape)}, "
1337
+ f"{self.codegen_shape_tuple(stride)}, "
1338
+ f"{dtype})"
1339
+ )
1340
+ # all other devices:
1341
+ return (
1342
+ f"{name} = empty_strided("
1343
+ f"{self.codegen_shape_tuple(shape)}, "
1344
+ f"{self.codegen_shape_tuple(stride)}, "
1345
+ f"device='{device.type}', dtype={dtype})"
1346
+ )
1347
+
1348
    def make_tensor_alias(self, new_name, old_name, comment=""):
        # Rebind old_name under new_name, annotated with an inline comment.
        return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}"
1350
+
1351
    def make_buffer_free(self, buffer):
        # The Python wrapper frees a buffer by deleting its name.
        return f"del {buffer.get_name()}"
1353
+
1354
+ def make_free_by_names(self, names_to_del: List[str]):
1355
+ return f"del {', '.join(name for name in names_to_del)}"
1356
+
1357
    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
        # Rebind old storage under the new buffer name (caller has verified
        # sizes/strides match exactly); `del_line` optionally frees old_name.
        return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse"
1359
+
1360
    def make_buffer_reuse(self, old, new, delete_old: bool):
        # Render code that reuses `old`'s storage for `new` (dtypes must match).
        assert old.get_dtype() == new.get_dtype()
        old_name = old.get_name()
        new_name = new.get_name()
        del_line = ";"
        if old_name not in V.graph.get_output_names() and delete_old:
            # Safe to drop the old name once the new one aliases its storage.
            del_line = f"; {self.make_buffer_free(old)}"

        if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
            # Identical geometry: a plain rebinding suffices.
            if old_name in self.stack_allocated_buffers:
                self.stack_allocated_buffers[new_name] = new
            return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)

        # Different size/stride: reinterpret the old storage with the new view.
        reinterpret_view = self.codegen_reinterpret_view(
            old, new.get_size(), new.get_stride(), 0, self.wrapper_call
        )
        if reinterpret_view in self.stack_allocated_buffers:
            self.stack_allocated_buffers[new_name] = new
        return f"{self.declare_maybe_reference}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse"
1379
+
1380
    def codegen_deferred_allocation(self, name, layout):
        # Emit the alias assignment lazily: DeferredLine drops the line if
        # `name` is later removed from the graph.
        self.writeline(
            DeferredLine(
                name,
                f"{self.declare_maybe_reference}{name} = {layout.view.codegen_reference()}{self.ending} "
                f"{self.comment} alias",
            )
        )
1388
+
1389
    def codegen_allocation(self, buffer):
        # Emit (or schedule) the allocation for `buffer`, at most once.
        assert (
            buffer.get_workspace_size() == 0
        ), "Only support zero workspace size for now!"

        name = buffer.get_name()

        if name in V.graph.removed_buffers or name in self.allocated:
            return
        self.allocated.add(name)
        if isinstance(
            buffer,
            (ir.ExternKernelAlloc, ir.MultiOutput),
        ):
            # These nodes produce their own storage; nothing to allocate here.
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayout):
            # Mutations write into an already-existing buffer.
            return
        if isinstance(layout, ir.AliasedLayout):
            assert isinstance(
                layout.view, ir.ReinterpretView
            ), f"unexpected {type(layout.view)}: {layout.view}"
            # Allocate the underlying storage first, then emit the alias
            # (deferred so it disappears if `name` gets removed).
            self.codegen_allocation(layout.view.data)
            self.codegen_deferred_allocation(name, layout)
            return

        self.writeline(AllocateLine(self, buffer))
1417
+
1418
    def codegen_free(self, buffer):
        # Emit a free for `buffer` once it is dead.
        assert (
            buffer.get_workspace_size() == 0
        ), "Only support zero workspace size for now!"

        name = buffer.get_name()

        # can be freed but not reused
        if isinstance(buffer, ir.InputBuffer):
            self.writeline(self.make_buffer_free(buffer))
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        # Deferred: only emits `del` if the buffer was not reused later on.
        self.writeline(FreeIfNotReusedLine(self, buffer))
1435
+
1436
+ def can_reuse(self, input_buffer, output_buffer=None):
1437
+ name = input_buffer.get_name()
1438
+ if (
1439
+ name in V.graph.removed_buffers
1440
+ or name in V.graph.graph_inputs
1441
+ or name in V.graph.constants
1442
+ or name in V.graph.never_reuse_buffers
1443
+ or name in self.freed
1444
+ ):
1445
+ return False
1446
+
1447
+ return True
1448
+
1449
+ def did_reuse(self, buffer, reused_buffer):
1450
+ # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
1451
+ # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
1452
+ return (
1453
+ buffer.get_name() in self.reuses
1454
+ and self.reuses[buffer.get_name()] == reused_buffer.get_name()
1455
+ )
1456
+
1457
    def codegen_inplace_reuse(self, input_buffer, output_buffer):
        # Reuse the input's storage for the output; both must share the same
        # reuse key (compatible storage).
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        # Bookkeeping: input counts as freed, output as allocated.
        self.freed.add(input_buffer.get_name())
        self.allocated.add(output_buffer.get_name())
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.writeline(ReuseLine(self, input_buffer, output_buffer))
1464
+
1465
+ def codegen_unbacked_symbol_decl(self, symbol):
1466
+ name = str(symbol)
1467
+ if name in self.unbacked_symbol_decls:
1468
+ return name
1469
+ else:
1470
+ # When in CppWrapperCpu, we should only generate the declaration once
1471
+ self.unbacked_symbol_decls.add(name)
1472
+ return self.declare + name
1473
+
1474
+ def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
1475
+ for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
1476
+ self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}")
1477
+
1478
+ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
1479
+ for inner_output, outer_output in zip(
1480
+ subgraph.graph.graph_outputs, outer_outputs
1481
+ ):
1482
+ self.writeline(
1483
+ f"{outer_output} = {inner_output.codegen_reference()}{self.ending}"
1484
+ )
1485
+
1486
    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
        # Generate code for a nested subgraph inline in the current wrapper:
        # bind inputs, codegen the body under the subgraph's graph handler,
        # then copy outputs back out.
        try:
            self.push_codegened_graph(subgraph.graph)
            self.writeline(f"{self.comment} subgraph: {subgraph.name}")
            self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
            parent_graph = V.graph
            with V.set_graph_handler(subgraph.graph):
                subgraph.graph.codegen_subgraph(
                    parent_graph=parent_graph,
                )
            self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
        finally:
            # Always rebalance the codegened-graph stack, even on error.
            self.pop_codegened_graph()
1499
+
1500
    def codegen_conditional(self, conditional):
        # Lower a conditional to a Python if/else over the two pre-captured
        # subgraphs, with both branches writing into a shared results list.
        name = conditional.get_name()
        outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
        # Branch outputs land in `name[i]` slots so either branch can fill them.
        outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]

        self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
        self.writeline(f"if {conditional.predicate.codegen_reference()}.item():")
        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
        self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
        self.writeline(ExitSubgraphLine(self))
        self.writeline("else:")
        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
        self.writeline(ExitSubgraphLine(self))
1514
+
1515
+ @staticmethod
1516
+ def statically_known_int_or_none(x):
1517
+ try:
1518
+ val = V.graph._shape_env._maybe_evaluate_static(x)
1519
+ return int(x)
1520
+ except Exception:
1521
+ return None
1522
+
1523
+ @staticmethod
1524
+ def statically_known_list_of_ints_or_none(lst):
1525
+ result = []
1526
+ for x in lst:
1527
+ num = WrapperCodeGen.statically_known_int_or_none(x)
1528
+ if num is None:
1529
+ return None
1530
+ result.append(num)
1531
+ return result
1532
+
1533
    @staticmethod
    def is_statically_known_list_of_ints(lst):
        # True iff every element of `lst` statically evaluates to an int.
        return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None
1536
+
1537
    @staticmethod
    def static_shape_for_buffer_or_none(buffer):
        # The buffer's shape as a list of ints if fully static, else None.
        return WrapperCodeGen.statically_known_list_of_ints_or_none(buffer.get_size())
1540
+
1541
    @staticmethod
    def can_prove_buffer_has_static_shape(buffer):
        # True iff every dimension of `buffer` is statically known.
        return WrapperCodeGen.static_shape_for_buffer_or_none(buffer) is not None
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from enum import IntEnum
3
+
4
+ import sympy
5
+
6
+ import torch
7
+ from . import ir
8
+
9
+ from .utils import get_dtype_size, sympy_product
10
+ from .virtualized import V
11
+
12
+
13
class NCCL_COLL(IntEnum):
    """Collective kinds whose runtime this module can model."""

    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2
17
+
18
+
19
class NVIDIA_GPU_TYPE(IntEnum):
    """GPU generations distinguished by the bandwidth tables (indexes into
    llMaxBws)."""

    VOLTA = 0
    AMPERE = 1
    HOPPER = 2
23
+
24
+
25
def get_gpu_type() -> NVIDIA_GPU_TYPE:
    """Classify the local NVIDIA GPU from `collect_env` output.

    Unknown GPU models are treated as Ampere.
    """
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    known_models = (
        ("V100", NVIDIA_GPU_TYPE.VOLTA),
        ("A100", NVIDIA_GPU_TYPE.AMPERE),
        ("H100", NVIDIA_GPU_TYPE.HOPPER),
    )
    for marker, gpu_type in known_models:
        if marker in gpu_info:
            return gpu_type
    # for other gpu types, assume Ampere
    return NVIDIA_GPU_TYPE.AMPERE
36
+
37
+
38
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
    """Classify a collective IR node as all_reduce / all_gather / reduce_scatter.

    For the generic `_CollectiveKernel` the kind is inferred from the Python
    kernel name; legacy typed nodes are matched by class.

    Raises:
        ValueError: if the node is not a recognized collective.
            (Was a bare `Exception`; ValueError is more specific and is still
            caught by any existing `except Exception` handlers.)
    """
    if isinstance(node, ir._CollectiveKernel):
        kernel_name = node.python_kernel_name
        assert kernel_name is not None
        if "all_reduce" in kernel_name:
            return NCCL_COLL.ALL_REDUCE
        elif "all_gather" in kernel_name:
            return NCCL_COLL.ALL_GATHER
        elif "reduce_scatter" in kernel_name:
            return NCCL_COLL.REDUCE_SCATTER
        else:
            raise ValueError(f"Unsupported collective kernel: {kernel_name}")

    if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)):
        return NCCL_COLL.ALL_REDUCE
    elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)):
        return NCCL_COLL.ALL_GATHER
    elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)):
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise ValueError(f"Unsupported collective type: {node}")
59
+
60
+
61
def get_collective_input_size_bytes(node: ir.IRNode) -> int:
    """Sum the byte sizes of all inputs to a collective IR node.

    Symbolic numels fall back to the size-hint heuristic; concrete sympy
    integers are converted directly. (Removed a dead local `shape` that the
    original assigned and never used.)
    """
    sz_bytes = 0
    for inp in node.inputs:  # type: ignore[attr-defined]
        numel = sympy_product(inp.layout.size)
        if isinstance(numel, sympy.Integer):
            # For ease of testing
            numel = int(numel)
        else:
            numel = V.graph.sizevars.size_hint(numel)
        sz_bytes += numel * get_dtype_size(inp.layout.dtype)
    return sz_bytes
73
+
74
+
75
def get_collective_group_size(node: ir.IRNode) -> int:
    """Number of ranks participating in the collective `node`."""
    if type(node) == ir._CollectiveKernel:
        from torch.distributed.distributed_c10d import _get_group_size_by_name

        # The process-group name is the last constant argument.
        return _get_group_size_by_name(node.constant_args[-1])
    if isinstance(node, ir.CollectiveKernel):
        return node.constant_args[2]  # type: ignore[attr-defined]
    raise TypeError(f"Unsupported collective type: {node}")
84
+
85
+
86
+ ####################################################################################################################
87
+ # The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
88
+ ####################################################################################################################
89
+
90
+
91
class NCCL_HW(IntEnum):
    """Interconnect hardware tiers (first index into hwLat)."""

    NVLINK = 0
    PCI = 1
    NET = 2
95
+
96
+
97
class NCCL_ALGO(IntEnum):
    """NCCL algorithm (second index into hwLat / first into baseLat)."""

    TREE = 0
    RING = 1
100
+
101
+
102
class NCCL_PROTO(IntEnum):
    # The ordering and enum values here matches original in
    # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
    # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
    LL = 0  # Low-latency
    # LL128 = 1 # Low-latency 128-byte
    # SIMPLE = 2
109
+
110
+
111
# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
baseLat = [
    # Tree
    [
        6.8,  # LL
    ],
    # Ring
    [
        6.6,  # LL
    ],
]

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = [
    # NVLINK
    [
        [0.6],  # Tree (LL)
        [0.6],  # Ring (LL)
    ],
    # PCI
    [
        [1.0],  # Tree (LL)
        [1.0],  # Ring (LL)
    ],
    # NET
    [
        [5.0],  # Tree (LL)
        [2.7],  # Ring (LL)
    ],
]


# LL128 max BW per channel
# Indexed [gpu-generation][node-count tier]; values in GB/s.
llMaxBws = [
    # Volta-N1/Intel-N2/Intel-N4
    [
        39.0,
        39.0,
        20.4,
    ],
    # Ampere-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
    # Hopper-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
]
167
+
168
+
169
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

    Assumptions:
    - only ring algorithm (NCCL_ALGO_RING) is used
    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather

    (Change vs. original: removed the unused local `hw`, which was computed
    and never read.)
    """
    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    group_size = get_collective_group_size(node)
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # this is total # of gpus globally that participate in this collective op

    if nRanks <= 1:
        # Single rank: the collective is a no-op.
        return 0

    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
    coll = get_collective_type(node)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns

    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw

    compCapIndex = get_gpu_type()
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2]

    # NOTE: each step of ring algorithm is synchronized,
    # and is bottlenecked by the slowest link which is the inter-node interconnect.
    # hence when nNodes >= 2, bw is inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # have this as `if nNodes <= 2` which seems wrong. Corrected it here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    busBw = min(
        llMaxBw,
        busBw
        * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
    )

    # Number of ring steps; get_collective_type guarantees one of these
    # branches is taken (hence the possibly-undefined ignores below).
    if coll == NCCL_COLL.ALL_REDUCE:
        nsteps = 2 * (nRanks - 1)
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nsteps = nRanks - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps  # type: ignore[possibly-undefined]
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9

    # =============== latency computation ===============
    intraHw = NCCL_HW.NVLINK

    if coll == NCCL_COLL.ALL_REDUCE:
        if nNodes > 1:
            nInterSteps = 2 * nNodes
        else:
            nInterSteps = 0
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo][nccl_proto]
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
    if nNodes > 1:
        netOverhead = 1.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat  # type: ignore[possibly-undefined]
    # Convert us to ns
    latency_ns = latency * 1e3

    # =============== final result ===============
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns
269
+
270
+
271
+ ################################################################################################################
272
+ # The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
273
+ ################################################################################################################
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py ADDED
@@ -0,0 +1,2159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables,
3
+ which share the same memory pool. Sharing a memory pool is an extremely
4
+ important optimization when chaining multiple CUDA graphs together, as it
5
+ prevents you from needing to copy intermediate tensors from one graph to the
6
+ next, and reduces overall memory usage by allowing dead memory from the first
7
+ pool to be reused in the second.
8
+
9
+ The standard graph/make_graph_callables support sharing memory pool, but
10
+ with a lot of caveats. CUDA graph trees remove these restrictions:
11
+
12
+ * Previously, if you recorded graphs A, B, you had to replay A, B in that
13
+ order. With CUDA graph trees, after replaying A, you can change your
14
+ mind and record/replay a different graph B'; we will support efficient
15
+ execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In
16
+ other words: we support arbitrary trees of CUDA graph operations, not just
17
+ sequences (this is why this feature is called CUDA graph trees.)
18
+
19
+ * Previously, if you executed graph A, some non-CUDA graph code, and then
20
+ graph B, after executing graph B, it was not safe to retain any references
21
+ to intermediates produced by A. With CUDA graph trees, we track if any
22
+ outputs of graph A are still live by the time graph B is run, and make
23
+ sure graph B doesn't clobber there memory when reusing the CUDA graphs
24
+ pool. You'll get a separate recording of B depending on what tensors
25
+ stay live or dead.
26
+
27
+ CUDA graph trees are flexible enough to be used in Dynamo across graph breaks,
28
+ which is their primary use case.
29
+
30
+ The ability to switch from replay to record is fairly nontrivial: remember that
31
+ when you replay a CUDA graph, you only replay CUDA operations; no CPU side state
32
+ is updated. In particular, the CPU-side book-keeping for the allocator is not
33
+ reconstructed. However, to record a new child CUDA graph, we must restore this
34
+ book-keeping. This is what checkpoint pool state is used for.
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import contextlib
40
+ import dataclasses
41
+ import functools
42
+ import gc
43
+ import itertools
44
+ import operator
45
+ import sys
46
+ import threading
47
+ import traceback
48
+ import warnings
49
+ import weakref
50
+ from collections import defaultdict
51
+
52
+ from enum import auto, Enum
53
+ from typing import (
54
+ Any,
55
+ Callable,
56
+ cast,
57
+ Dict,
58
+ Iterator,
59
+ List,
60
+ Optional,
61
+ Sequence,
62
+ Set,
63
+ Tuple,
64
+ Union,
65
+ )
66
+
67
+ import torch.fx
68
+ from torch import Tensor
69
+ from torch._dynamo.mutation_guard import GenerationTracker
70
+ from torch._dynamo.utils import preserve_rng_state
71
+ from torch._inductor.compile_fx import (
72
+ align_inputs_from_check_idxs,
73
+ copy_misaligned_inputs,
74
+ get_expanded_dims,
75
+ get_input_idxs_to_check,
76
+ index_expanded_dims,
77
+ remove_unaligned_input_idxs,
78
+ static_input,
79
+ )
80
+ from torch.multiprocessing.reductions import StorageWeakRef
81
+ from torch.storage import UntypedStorage
82
+ from torch.types import _bool
83
+ from torch.utils import _pytree as pytree
84
+ from torch.utils.weak import TensorWeakRef
85
+
86
# Semantic aliases for the raw ints exchanged with the caching-allocator APIs.
StorageWeakRefPointer = int
StorageDataPtr = int
NBytes = int
89
+
90
if torch.backends.cuda.is_built():
    from torch._C import (
        _cuda_CUDAAllocator_AllocatorState as AllocatorState,
        _set_cached_tensors_enabled as _set_cached_tensors_enabled,
    )
else:
    # CPU-only builds: provide inert stand-ins so this module still imports.

    class AllocatorState:  # type: ignore[no-redef]
        pass

    def _set_cached_tensors_enabled(enabled: _bool) -> None:
        pass
102
+
103
+
104
# Artifact logger gated on the "cudagraphs" logging artifact (e.g. TORCH_LOGS=cudagraphs).
log = torch._logging.getArtifactLogger(__name__, "cudagraphs")


from . import config
108
+
109
+
110
@dataclasses.dataclass(frozen=True)
class GraphID:
    """Unique, hashable counter identifying a single cuda graph recording."""

    id: int
114
+
115
+
116
@dataclasses.dataclass(frozen=True)
class FunctionID:
    """Unique, hashable counter identifying a function wrapped in cudagraphify_impl."""

    id: int
120
+
121
+
122
@dataclasses.dataclass(frozen=True)
class WrappedFunction:
    """
    A function to be recorded for CUDA graph replay, together with the metadata
    needed to decide whether an existing recording in the CUDA graph tree
    applies to it.
    """

    model: Callable[..., Any]
    static_input_idxs: Sequence[int]
    id: FunctionID
    constants: Tuple[torch.Tensor, ...]
134
+
135
+
136
def clear_cublass_cache():
    """
    Clear the cuBLAS workspace cache.

    Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for
    doing warmup within a CUDAGraph private pool because we do not want persistent allocations from
    one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors
    from the previous generation are freed. This frees their memory in the pool, but not elsewhere.
    A tensor in the cublas workspace would continue to be in use by the workspace but would also get
    allocated in the next run. The memory would be in use in two places.

    To solve this, we clear cublas caches before and after warming up or recording. If a workspace is
    required it will be allocated to the cudagraph private pool and accounted for in the allocator for
    the duration of the program. There is no overhead to this on replay since cudagraphs removes
    allocation overhead.
    """
    # NB: name keeps the historical "cublass" spelling; renaming would break callers.
    torch._C._cuda_clearCublasWorkspaces()
150
+
151
+
152
@contextlib.contextmanager
def clear_cublas_manager():
    """Context manager that flushes the cuBLAS workspace cache on both entry and exit."""
    clear_cublass_cache()
    try:
        yield
    finally:
        clear_cublass_cache()
160
+
161
+
162
@contextlib.contextmanager
def disable_conv_cache_emptying():
    """Temporarily disable cudnn conv-benchmark cache emptying, restoring the prior setting on exit."""
    previous_setting = torch._C._cuda_get_conv_benchmark_empty_cache()
    torch._C._cudnn_set_conv_benchmark_empty_cache(False)
    try:
        yield
    finally:
        torch._C._cudnn_set_conv_benchmark_empty_cache(previous_setting)
170
+
171
+
172
@contextlib.contextmanager
def enable_history_recording():
    """Turn on history recording in the CUDA Caching Allocator, reverting on exit if we enabled it."""
    was_enabled = torch._C._cuda_isHistoryEnabled()
    try:
        if not was_enabled:
            torch.cuda.memory._record_memory_history()
        yield
    finally:
        # Only tear down recording we started ourselves.
        if not was_enabled:
            torch.cuda.memory._record_memory_history(None)
183
+
184
+
185
def get_history_recording():
    """Return an allocator-history-recording context manager, or a no-op one when disabled."""
    # TODO - remove, prevents cleanup
    if config.triton.cudagraph_trees_history_recording:
        return enable_history_recording()
    return contextlib.nullcontext()
190
+
191
+
192
class TreeManagerContainer:
    """
    Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator,
    the tree and its corresponding memory pool should be kept alive as long as any outstanding
    graph or tensor which is an output of a graph remains alive.

    There is a single tree manager container per device.

    The lifecycle of a tree_manager is:
    -  Is constructed, no graph, no fns, no tensors
    -  Tree manager is fetched, resulting in tree manager being allocated
    -  We generate a bunch of functions, calling add_strong_reference
    -  These functions die, calling finalize_reference
    -  When all the functions die, we finalize_tree_manager.

    TODO: in the future, we would like to do the following once storage weak refs land
    -  We look for all the live storages and add references to THOSE
    -  We count as storages die
    -  All the storages are dead, we deallocate the tree manager
    """

    def __init__(self, device_index: int):
        # This class keeps a strong reference to tree_manager, but it is reset to
        # None once all other strong references to it die (_finalize_tree_manager).
        # We need a strong reference so that we can still access its attributes upon cleanup.
        self.tree_manager: Optional[CUDAGraphTreeManager] = None

        # Number of outstanding references to the current tree manager
        self.live_cudagraphify_fns = 0

        self.device_index = device_index

        # Following two objects are only set in the case that Tensor outputs outlive
        # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from
        # deallocation.
        self.live_storages_count = 0
        self.graph: Optional[torch.cuda.CUDAGraph] = None

        # Guards all mutation of the counters and tree_manager above.
        self.lock = threading.Lock()

    def _finalize_tensor(self):
        # weakref.finalize callback for an output storage (see the commented-out
        # registration in _finalize_tree_manager).
        with self.lock:
            self.live_storages_count -= 1
            if self.live_storages_count == 0:
                self.graph = None

                # manager was used again after existing cleanup,
                # we shouldn't set it to None
                if self.live_cudagraphify_fns == 0:
                    self.tree_manager = None

    def finalize_cudagraphify_fn(self):
        # weakref.finalize callback registered in add_strong_reference.
        with self.lock:
            self.live_cudagraphify_fns -= 1
            if self.live_cudagraphify_fns == 0:
                self._finalize_tree_manager()

    def _finalize_tree_manager(self):
        # Caller must hold self.lock (both call sites do).
        assert self.lock.locked()
        self.tree_manager = None

        # TODO - when issue #91395 is landed, we can set a weakref on
        # storages and trigger a deallocation when all outputs of the
        # cudagraph are dead.

        # live_storages = list(
        #     tree_manager.live_cudagraph_pool_storages_in_curr_execution()
        # )

        # # Maintain reference to graph to keep tensors alive
        # assert len(tree_manager.roots) > 0, "expected at least one use"
        # root = next(tree_manager.get_roots())
        # self.graph = root.graph
        # seen_storages = set()
        # for stor in live_storages:
        #     if stor in seen_storages:
        #         continue
        #     seen_storages.add(stor)
        #     self.live_storages_count += 1
        # .   weakref.finalize(stor, self._finalize_tensor)

    def add_strong_reference(self, fn: Callable[..., Any]):
        # Record fn as keeping the tree manager alive; when fn is garbage
        # collected, finalize_cudagraphify_fn decrements the live count.
        with self.lock:
            self.live_cudagraphify_fns += 1

        weakref.finalize(fn, self.finalize_cudagraphify_fn)

    def get_tree_manager(self) -> CUDAGraphTreeManager:
        # Lazily construct the per-device CUDAGraphTreeManager.
        with self.lock:
            if self.tree_manager is None:
                self.tree_manager = CUDAGraphTreeManager(self.device_index)
            return self.tree_manager
284
+
285
+
286
local = threading.local()

# one tree manager per device; the matching lock guards lazy creation of each container
local.tree_manager_containers = {}
local.tree_manager_locks = defaultdict(threading.Lock)
291
+
292
+
293
# only incremented by user call of mark_step_begin
class MarkStepBox:
    # Counts downward (negative values) so it can never collide with the
    # GenerationTracker counter, which only counts upward.
    mark_step_counter = 0
296
+
297
+
298
+ # We need to register this as an object that will be copied over as TLS when new
299
+ # threads are created in autograd
300
+ torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers)
301
+ torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)
302
+
303
+
304
def mark_step_begin():
    """Indicates that a new iteration of inference or training is about to begin."""
    # Decrement (rather than increment) so these values stay disjoint from the
    # GenerationTracking counter.
    MarkStepBox.mark_step_counter -= 1
309
+
310
+
311
def reset_cudagraph_trees():
    """Clear all cudagraph trees."""
    # see shutdown below for why this is necessary
    containers = get_obj(local, "tree_manager_containers")
    locks = get_obj(local, "tree_manager_locks")
    for device, device_lock in locks.items():
        with device_lock:
            container = containers.get(device)
            if container and container.tree_manager:
                container.tree_manager.shutdown()

    _set_cached_tensors_enabled(False)
    containers.clear()

    MarkStepBox.mark_step_counter = 0
328
+
329
+
330
def get_obj(local, attr_name):
    """Look up attr_name on the thread-local object, falling back to autograd-propagated TLS."""
    try:
        return getattr(local, attr_name)
    except AttributeError:
        # Autograd worker threads receive these objects via TLS rather than
        # via the module-level `local` object.
        assert torch._C._is_key_in_tls(attr_name)
        return torch._C._get_obj_in_tls(attr_name)
336
+
337
+
338
def get_container(device_index: int):
    """Return the TreeManagerContainer for device_index, creating it on first use."""
    containers = get_obj(local, "tree_manager_containers")
    lock = get_obj(local, "tree_manager_locks")[device_index]

    with lock:
        container = containers.get(device_index)
        if container is None:
            container = containers[device_index] = TreeManagerContainer(device_index)

        return container
347
+
348
+
349
def get_manager(
    device_index: int, create_if_none_exists=True
) -> Optional[CUDAGraphTreeManager]:
    """Return the per-device tree manager, optionally allocating it on first access."""
    container = get_container(device_index)
    if create_if_none_exists:
        return container.get_tree_manager()
    return container.tree_manager
355
+
356
+
357
def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs):
    """Memoize cudagraphify over the int (symint) values present in `inputs`."""
    fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {}

    # Detect int inputs: we need to index on these
    int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
    get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None

    del inputs

    def deferred_cudagraphify(inputs):
        int_key = get_ints(inputs)
        cached = fn_cache.get(int_key)
        if cached is not None:
            return cached(inputs)

        if int_key is None:
            log.info("recording cudagraph tree for graph without symints")
        else:
            log.info("recording cudagraph tree for symint key %s", int_key)

        # Alignment handling: first compute which indices must be checked, then
        # prune unaligned entries from the static set, and finally copy any
        # misaligned inputs into aligned buffers.
        check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
        new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
        copy_misaligned_inputs(inputs, check_input_idxs)

        fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
        fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs)
        fn_cache[int_key] = fn

        return out

    return deferred_cudagraphify
390
+
391
+
392
def cudagraphify(
    model,
    inputs,
    static_input_idxs=(),
    *,
    device_index: int,
    is_backward: bool,
    is_inference: bool,
    stack_traces: Optional[StackTraces] = None,
    constants: Tuple[torch.Tensor, ...] = (),
):
    """Register `model` with the per-device tree manager; returns the manager's (fn, outputs)."""
    assert not (is_backward and is_inference)
    if is_backward:
        mode = CompilationMode.BACKWARD
    elif is_inference:
        mode = CompilationMode.INFERENCE
    else:
        mode = CompilationMode.FORWARD

    manager = get_container(device_index).get_tree_manager()
    return manager.add_function(
        model,
        inputs,
        static_input_idxs,
        stack_traces,
        mode,
        constants,
    )
419
+
420
+
421
+ class StorageWeakRefWrapper:
422
+ """
423
+ Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
424
+ """
425
+
426
+ __slots__ = ["ref", "_data_ptr", "extra_ref_check"]
427
+
428
+ storage_ref: Optional[StorageWeakRef]
429
+
430
+ def __init__(
431
+ self,
432
+ inp: Union[Tensor, UntypedStorage],
433
+ extra_ref_check: Optional[Callable[[], None]] = None,
434
+ ):
435
+ """
436
+ extra_ref_check is an additional check we need to run to check if the
437
+ weak ref has expired. in checking storage use count we assume extra_ref_check
438
+ will hold an additional reference to the storage.
439
+ """
440
+ if isinstance(inp, Tensor):
441
+ stor = inp.untyped_storage()
442
+ else:
443
+ assert isinstance(inp, UntypedStorage)
444
+ stor = inp
445
+ self.ref = StorageWeakRef(stor)
446
+ self._data_ptr = stor.data_ptr()
447
+ self.extra_ref_check = extra_ref_check
448
+
449
+ @classmethod
450
+ def from_weakref_and_data_ptr(cls, cdata, data_ptr, extra_ref_check=None):
451
+ instance = cls.__new__(cls)
452
+ instance._data_ptr = data_ptr
453
+ instance.ref = StorageWeakRef.from_weakref(cdata)
454
+ instance.extra_ref_check = extra_ref_check
455
+ return instance
456
+
457
+ def __call__(self) -> Optional[StorageWeakRefPointer]:
458
+ if self.expired():
459
+ return None
460
+
461
+ return self.ref.cdata
462
+
463
+ def swap_weakref(self, cdata):
464
+ self.ref.__del__()
465
+ self.ref.cdata = cdata
466
+
467
+ def data_ptr(self) -> int:
468
+ "NB: returns the data ptr even if the storage has expired"
469
+ return self._data_ptr
470
+
471
+ def remove_extra_reference(self):
472
+ self.extra_ref_check = None
473
+
474
+ def expired(self):
475
+ if self.extra_ref_check is not None and not self.extra_ref_check():
476
+ return False
477
+
478
+ # if extra_ref_check is not None we expect an additional reference
479
+ stor_count = torch._C._storage_Use_Count(self.ref.cdata)
480
+ return (stor_count - (self.extra_ref_check is not None)) == 0
481
+
482
+ def __repr__(self):
483
+ if self.ref is None or self.ref.expired():
484
+ return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
485
+ else:
486
+ return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"
487
+
488
+
489
def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool:
    """True iff the wrapped storage is still allocated."""
    return maybe_deref(weak_ref) is not None
491
+
492
+
493
def maybe_deref(
    weak_ref: Optional[StorageWeakRefWrapper],
) -> Optional[Tuple[StorageWeakRefPointer, int]]:
    """Dereference weak_ref: (weakref cdata, recorded data ptr) while alive, else None."""
    if weak_ref is None:
        return None
    cdata = weak_ref()
    if cdata is None:
        return None
    # NB: the live cdata does not necessarily correspond to weak_ref.data_ptr()
    return cdata, weak_ref.data_ptr()
503
+
504
+
505
@contextlib.contextmanager
def _use_cuda_memory_pool_manager(device, mem_pool, stream):
    """
    Context manager to use cuda graph pool for new allocations. If you use this manager
    all cudagraph tensors in use should be reflected in the allocator or they will be overwritten.
    existing_graph should already have been used in a capture, and the mem_pool must already exist,
    because this manager will not preserve a reference to the pool which keeps it alive.
    """
    # Drain all outstanding work before redirecting allocations into the pool.
    torch.cuda.synchronize()
    stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(stream), torch.device(device):
        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
        try:
            yield
        finally:
            # Always undo the redirection, even if the body raises.
            torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
            torch._C._cuda_releasePool(device, mem_pool)

    torch.cuda.current_stream().wait_stream(stream)
525
+
526
+
527
def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
    """Wrap a tensor's storage in a StorageWeakRefWrapper; None passes through unchanged."""
    if isinstance(t, torch.Tensor):
        return StorageWeakRefWrapper(t)
    assert t is None
    return None
532
+
533
+
534
# A path index of (depth, offset) indices into a graph that is `depth` number of nodes from the root
# at graph output offset
PathOutputIndex = Tuple[int, int]

# For each node in the path, for each output, is the output alive
PathLiveness = List[List[bool]]

# One (optional) stack trace string per graph output
StackTraces = List[Optional[str]]
542
+
543
+
544
class CUDAWarmupNode:
    """
    Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes
    apis to get the live storages in the current chain of warmup.

    A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have
    CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable
    memory addresses.

    CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes.
    - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the
    first instance of warmup, these are not finalized yet.
    - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup.
    - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler.

    NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and
    `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility.
    """

    def __init__(
        self,
        wrapped_function: WrappedFunction,
        parent,
        cuda_graphs_pool: Tuple[int, int],
        existing_cuda_graph: Optional[torch.cuda.CUDAGraph],
        device_index: int,
        stack_traces: Optional[StackTraces],
        stream: torch.cuda.Stream,
        already_warm: bool,
    ):
        self.wrapped_function = wrapped_function
        # parent may be a CUDAGraphNode or another CUDAWarmupNode (see class docstring)
        self.parent = parent
        self.cuda_graphs_pool = cuda_graphs_pool
        self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
        self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
        self.existing_cuda_graph = existing_cuda_graph
        self.has_run = False
        self.device_index = device_index
        self.stack_traces = stack_traces
        self.stream = stream
        self.already_warm = already_warm

    def run(self, new_inputs):
        # NOTE(review): has_run is initialized False and never set to True anywhere
        # in this class, so this assert appears vacuous — confirm whether callers
        # guarantee single use or whether `self.has_run = True` is missing here.
        assert not self.has_run, "Wrapped function should never be run twice"

        # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created
        # storages in path_live_weakrefs.
        existing_path_data_ptrs = {
            t.data_ptr() for t in self.path_live_weakrefs() if t()
        }

        def get_non_cudagraph_inps():
            # Data ptrs of input/constant storages that do not live in the cudagraph pool path.
            non_cudagraph_inps = set()
            for t in itertools.chain(new_inputs, self.wrapped_function.constants):
                if (
                    isinstance(t, torch.Tensor)
                    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
                ):
                    non_cudagraph_inps.add(t.untyped_storage().data_ptr())
            return non_cudagraph_inps

        non_cudagraph_inps = get_non_cudagraph_inps()

        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            refs = list(self.path_live_weakrefs())
            check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)

        with torch.cuda.device(
            self.device_index
        ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
            self.device_index, self.cuda_graphs_pool, self.stream
        ), get_history_recording():
            out = self.wrapped_function.model(new_inputs)

        # the model is expected to consume (clear) its input list
        assert len(new_inputs) == 0

        # sdpa returns cpu tensors when not recording cuda graph
        def add_ref(o):
            return (
                o is not None
                and isinstance(o, torch.Tensor)
                and o.is_cuda
                and o.untyped_storage().data_ptr() not in non_cudagraph_inps
                and o.untyped_storage().data_ptr() != 0
            )

        self.outputs_weakrefs.extend(
            [map_to_ref(o) if add_ref(o) else None for o in out]
        )
        self.tensor_weakrefs.extend(
            [TensorWeakRef(o) if add_ref(o) else None for o in out]
        )

        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            out_refs = self.path_live_weakrefs()
            new_storages = [
                t for t in out_refs if t.data_ptr() not in non_cudagraph_inps
            ]
            check_memory_pool(self.device_index, self.cuda_graphs_pool, new_storages)

        return out

    @property
    def _path_from_root(self):
        # Yields nodes root-first by walking parent pointers and reversing.
        nodes = []
        node = self
        while node:
            nodes.append(node)
            node = node.parent

        yield from reversed(nodes)

    def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
        "Returns all live storage weakrefs created by nodes in this path"
        for node in self._path_from_root:
            for output in node.outputs_weakrefs:
                if is_live(output):
                    yield output

    def all_outputs_are_dead(self):
        # True when no storage produced anywhere along this path is still alive.
        return not list(self.path_live_weakrefs())
665
+
666
+
667
# Aliases for List that say what the indices denote
InputList = List  # input indexes
OutputList = List  # output indexes
LevelList = List  # levels (distance from root of tree)
671
+
672
+
673
+ class OutputAliasInfo:
674
+ pass
675
+
676
+
677
+ class _UnaliasedStorage(OutputAliasInfo):
678
+ "Singleton to mark that the graph output constructs a new alias or is None"
679
+ pass
680
+
681
+
682
+ UnaliasedStorage = _UnaliasedStorage()
683
+
684
+
685
+ class AliasesPriorGraphOutput(OutputAliasInfo):
686
+ "Marks that the graph output aliases an output of a prior graph"
687
+ __slots__ = ["index"]
688
+
689
+ index: PathOutputIndex
690
+
691
+ def __init__(self, index: PathOutputIndex):
692
+ assert isinstance(index, tuple)
693
+ self.index = index
694
+
695
+
696
+ class AliasesNewOutput(OutputAliasInfo):
697
+ "Marks that the graph output aliases an index in the new, returned outputs"
698
+
699
+ __slots__ = ["index"]
700
+
701
+ index: int
702
+
703
+ def __init__(self, index):
704
+ assert isinstance(index, int)
705
+ self.index = index
706
+
707
+
708
+ class CUDAGraphNode:
709
+ """
710
+ A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool
711
+ and are structured into a tree, where there is a single recording that can precede it (parent) and multiple
712
+ subsequent recordings that may follow (children). A node will have no parent if it is the first recording
713
+ in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which
714
+ would force a dependency.
715
+
716
+ On first recording, all of the live tensors in the current CUDA Graph Node path will be
717
+ reflected in the corresponding private pool. On subsequent executions, the caching allocator
718
+ is unaffected when the graph is replayed.
719
+
720
+ In order to support recording a subsequent cuda graph recording after execution of this graph,
721
+ we checkpoint the state of the memory pool so that it may later be resumed.
722
+
723
+ WrappedFunction should have already been warmed up prior to invocation.
724
+
725
+ See [setCheckpointPoolState] for further explanation, as well as
726
+ https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png
727
+ """
728
+
729
    def __init__(
        self,
        wrapped_function: WrappedFunction,
        id: GraphID,
        parent: Optional[CUDAGraphNode],
        inputs: List[Tensor],
        cuda_graphs_pool: Tuple[int, int],
        device_index: int,
        stack_traces: Optional[StackTraces],
        stream: torch.cuda.Stream,
    ):
        """
        Record `wrapped_function` into a new CUDA graph in the tree rooted above
        `parent` (None for a root node). NB: recording happens here, in the
        constructor; run() is expected to be called immediately afterwards
        (see the "DO THE RECORDING" comment below).
        """
        assert isinstance(inputs, (list, tuple))

        self.wrapped_function = wrapped_function
        self.id = id
        self.device = device_index
        self.stack_traces = stack_traces
        self.stream = stream

        # if this is a root parent will be None. use weakref to prevent reference cycle
        self._parent = weakref.ref(parent) if parent is not None else None
        # reference to the shared memory pool for the entire cuda graphs tree
        self.cuda_graphs_pool = cuda_graphs_pool

        # A single wrapped function may be recorded multiple times if memory patterns or
        # invariants change from one execution to the next
        self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)

        # StorageWeakRef maintains whether the Storage C++ object remains allocated,
        # not whether the corresponding memory has been deallocated. In order
        # to use them to track memory deallocations we must maintain a single StorageWeakRef
        # for all Storages that reference that memory (even if we are constructing Storages
        # that do not have a deallocator function). We maintain one single storage_cache
        # as we execute any tree path. When we retrieve a storage from the cache we
        # check that it is still alive, and we hash based on observed recording data ptr
        # and storage cdata.

        # we preserve a single reference to executed outputs that is then referenced
        # in children to avoid children having to chase parent pointers in the hot path
        # DO NOT reassign output_weakrefs, only call `clear()`
        # Path is a series of nodes from root to the current node
        self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = []
        self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
            node.outputs_weakrefs for node in self._path_from_root
        ]
        self.path_stacktraces: LevelList[StackTraces] = [
            node.stack_traces for node in self._path_from_root
        ]
        self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []

        # tensors which are outputs of previous graphs in the tree
        self.cudagraph_managed_idxs: List[int] = [
            idx
            for idx, t in enumerate(inputs)
            if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
        ]

        # cudagraph-managed tensors have stable addresses, so they are treated as static
        self.static_input_idxs: List[int] = list(
            set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
        )

        self.static_input_data_ptrs: InputList[Optional[int]] = [
            (
                inputs[i].data_ptr()
                if isinstance(inputs[i], torch.Tensor) and i in self.static_input_idxs
                else None
            )
            for i in range(len(inputs))
        ]

        # When we checkpoint, and free generations, we will be manually freeing the outputs
        # of CUDAGraphNodes. We should not be freeing parameters, nor do we need to account for
        # their liveness (they are static), so we need to compute which outputs are aliases of
        # parameters. Some static inputs are saved tensors from the forward that die in the backward.
        # Their locations are static but lifetimes are not. We only include the persistent static
        # data ptrs below because the non persistent data ptrs may be outputs of this record and
        # fresh allocations.

        # precompute expanded dims to avoid computing in the hot path
        self.expanded_dims: List[List[int]] = [
            get_expanded_dims(x)
            if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
            else []
            for idx, x in enumerate(inputs)
        ]

        # For each node in path, which outputs were observed to be live
        # before invoking graph recording, and after graph recording
        self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = []
        self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = []

        # List of Tuples of (depth, output_index) that index into node at depth
        # number of nodes from root and output_index of outputs. Will index into
        # path_weakrefs.
        self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
        self.expected_dead_indices_after_graph: List[PathOutputIndex] = []

        # all live indices after graph recording
        self.live_indices_after_graph: List[PathOutputIndex] = []

        if self.parent is not None:
            previous_liveness = self.parent.recorded_liveness_after_graph
            curr_liveness = self._get_liveness(self.path_weakrefs)

            different_indices = self._get_different_indices(
                previous_liveness, curr_liveness
            )

            self.recorded_liveness_before_graph = curr_liveness
            self.expected_dead_indices_before_graph = different_indices

        recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
        # recording inputs will copy over memory, so we can free non recording inputs
        inputs.clear()
        del inputs

        # graph used for recording model invocation
        self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()

        # we allocate non-static inputs within the same memory pool as the CUDAGraph
        # which we will record the model with. For memory efficiency, it is important
        # to reclaim the input memory when the inputs are no longer live. To accomplish this,
        # we reconstruct tensors at the correct data pointers of our inputs which are
        # non owning and do not prevent deallocation. On subsequent executions, input values
        # will be copied over to these tensors.
        self.reconstructed_inputs: InputList[Union[Tensor, int]] = [
            self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
            if isinstance(x, torch.Tensor)
            else x
            for x in recording_inputs
        ]

        # DO THE RECORDING!!!
        # We record the CUDA graph in the constructor of CUDAGraphNode, which
        # gives you what the CPU side compute of the function would do. We
        # don't throw the recording outputs away: their memory is
        # correctly accounted for in the CUDAGraphs caching allocator. This
        # means on the very FIRST run of the CUDA graph node, we can directly
        # do more recording, because we have a valid caching allocator state.
        # NB: This relies on run() being called immediately after the
        # constructor, otherwise this optimization would not be valid.

        # initialized below in _record

        self.checkpointed_caching_state: Optional[AllocatorState] = None

        # Output Storage Alias information, can be:
        # - A new, unaliased storage, or the output is None
        # - An alias of an output of a prior graph
        # - An alias of an output already created in the reconstructed outputs
        # This is None if the output in question is an int
        self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = []

        # is the output Storage unaliased in subsequent outputs, of all subsequent paths
        # if it is, we cached the output tensor and adjust storage liveness tracking to also
        # check if the output tensor does not have an additional python reference.
        # If a descendent node discovers it has an alias of a prior output, then the output
        # will no longer be cached in the ancestor.
        # The large majority of tensors are unaliased, and preserving aliased output tensors would add
        # significant additional complexity with marginal gains
        # The cached tensor outputs are added on the first execution, and cleared whenever we need
        # to do subsequent recording
        self.unaliased_in_all_paths: OutputList[bool] = []
        self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []

        # if an output aliases a static, persistent input then the corresponding Tensor will
        # be set here. These are different than cached tensors, because they are tensors that
        # are aliases of parameters that are always live.
        self.static_output_tensors: OutputList[Optional[Tensor]] = []

        # Cleared after recording
        self.recording_outputs: Optional[
            OutputList[Union[torch.Tensor, int]]
        ] = self._record(wrapped_function.model, recording_inputs)
        self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []

        # As with inputs, we do not want to keep the outputs permanently alive because that would prevent
        # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
        # needed to reconstruct instead.
        assert self.recording_outputs is not None
        for out in self.recording_outputs:
            if isinstance(out, torch.Tensor):
                self.outputs_metadata.append(
                    self._tensor_metadata(out, ignore_storage_offset=False)
                )
            else:
                assert isinstance(out, (int, type(None))), type(out)
                self.outputs_metadata.append(out)

        self.graph.replay()
919
+
920
+ def _copy_input(self, idx, dst, src):
921
+ expanded_dims = self.expanded_dims[idx]
922
+ dst = index_expanded_dims(dst, expanded_dims)
923
+ src = index_expanded_dims(src, expanded_dims)
924
+ # TODO - one jit kernel across multiple inputs
925
+ dst.copy_(src)
926
+
927
+ def run_first_inputs(self, new_inputs):
928
+ if config.triton.fast_path_cudagraph_asserts:
929
+ self.debug_check_invariants_before_invocation()
930
+
931
+ # graph is already invoked in the __init__
932
+ # inputs are copied over in _allocate_recording_inputs and subsequently cleared
933
+ assert len(new_inputs) == 0
934
+ outputs = self.recording_outputs
935
+ self.recording_outputs = None
936
+ return outputs
937
+
938
    def run(self, new_inputs):
        """Replay the recorded CUDA graph against `new_inputs` and return
        freshly reconstructed output tensors.

        Non-static inputs are copied into the graph's recorded input buffers;
        static inputs (e.g. parameters) are only checked for pointer stability.
        `new_inputs` is cleared so the caller does not keep the tensors alive.
        """
        if config.triton.fast_path_cudagraph_asserts:
            self.debug_check_invariants_before_invocation()

        assert len(self.static_input_data_ptrs) == len(new_inputs)
        # NB: this ranges over non-static inputs too
        for idx, data_ptr in enumerate(self.static_input_data_ptrs):
            if idx in self.cudagraph_managed_idxs:
                # memory already managed by an ancestor graph in this path
                continue
            if not isinstance(new_inputs[idx], torch.Tensor):
                pass
            elif data_ptr is not None:
                # static input, e.g., parameter
                assert data_ptr == new_inputs[idx].data_ptr()
            else:
                # non-static input, need to copy it into CUDA graph
                dst = self.reconstructed_inputs[idx]
                src = new_inputs[idx]
                self._copy_input(idx, dst, src)

        # drop our references so input memory can be reclaimed by the pool
        new_inputs.clear()
        self.run_graph()

        outputs = self.reconstruct_outputs()
        self.debug_check_invariants_after_invocation()

        return outputs
966
    def reconstruct_outputs(self):
        "Reconstruct output tensors according to their saved metadata and alias information"

        # Cached tensors will not yet be set on the first execution
        # They are also cleared in checkpointing, so if we checkpoint this node
        # and then execute it again we will need to repopulate cached tensors
        if not self.cached_tensor_outputs:
            self._initialize_cached_tensors()

        outputs: List[Optional[Union[int, torch.Tensor]]] = []

        for i, (storage_info, metadata) in enumerate(
            zip(self.output_storage_alias, self.outputs_metadata)
        ):
            if not isinstance(metadata, dict):  # tensor metadata
                # non-tensor output (int or None) is returned unchanged
                assert isinstance(metadata, (int, type(None)))
                outputs.append(metadata)
                continue

            cached_t = self.cached_tensor_outputs[i]
            if cached_t is not None:
                # No need to update weakrefs, already correctly initialized
                outputs.append(cached_t)
                continue

            static_t = self.static_output_tensors[i]
            if static_t is not None:
                # alias of an always-live static input; lifetime is not managed here
                assert self.outputs_weakrefs[i] is None
                outputs.append(static_t)
                continue

            # storage is an UntypedStorage (alias of a prior graph's output),
            # None (fresh storage), or an int index of an earlier output in
            # this same invocation whose storage we must share
            storage = self.prepare_alias_info_for_tensor_construction(
                storage_info, metadata
            )

            if isinstance(storage, UntypedStorage) or storage is None:
                out = self._reconstruct_from_tensor_metadata(metadata, storage)
            else:
                assert isinstance(storage, int)
                out = self._reconstruct_from_tensor_metadata(
                    metadata, cast(torch.Tensor, outputs[storage]).untyped_storage()
                )

            outputs.append(out)
            w = self.outputs_weakrefs[i]
            assert w is not None
            # point our liveness tracking at the newly constructed storage
            w.swap_weakref(out.untyped_storage()._weak_ref())

        return outputs
1016
+ def prepare_alias_info_for_tensor_construction(
1017
+ self,
1018
+ out_alias_info: Optional[OutputAliasInfo],
1019
+ metadata: Union[Dict[str, Any], int, None],
1020
+ ) -> Union[UntypedStorage, None, int]:
1021
+ if (
1022
+ isinstance(metadata, (int, type(None)))
1023
+ or out_alias_info is UnaliasedStorage
1024
+ ):
1025
+ return None
1026
+
1027
+ if isinstance(out_alias_info, AliasesPriorGraphOutput):
1028
+ depth, existing_output_index = out_alias_info.index
1029
+ ref = self.path_weakrefs[depth][existing_output_index]
1030
+ assert ref is not None
1031
+ return torch.UntypedStorage._new_with_weak_ptr(ref())
1032
+
1033
+ assert isinstance(out_alias_info, AliasesNewOutput)
1034
+ return out_alias_info.index
1035
+
1036
+ def prepare_storages_for_construction(
1037
+ self,
1038
+ ) -> List[Union[UntypedStorage, None, int]]:
1039
+ output_storages = []
1040
+ for output_storage_alias, metadata in zip(
1041
+ self.output_storage_alias, self.outputs_metadata
1042
+ ):
1043
+ output_storages.append(
1044
+ self.prepare_alias_info_for_tensor_construction(
1045
+ output_storage_alias, metadata
1046
+ )
1047
+ )
1048
+
1049
+ return output_storages
1050
+
1051
+ def run_graph(self):
1052
+ assert self.graph is not None
1053
+ self.graph.replay()
1054
+
1055
+ def all_outputs_are_dead(self):
1056
+ "All outputs of the path from this node to its root are dead"
1057
+ for depth, output_index in self.live_indices_after_graph:
1058
+ if is_live(self.path_weakrefs[depth][output_index]):
1059
+ return False
1060
+ return True
1061
+
1062
    def _record(self, model, inputs):
        """Record the model into this node's CUDA graph and return its outputs.

        Runs `model(inputs)` under `torch.cuda.graph` capture into the shared
        memory pool, after snapshotting which static inputs/constants are
        persistent so output aliases of them can be detected later.
        """

        def static_input_iter():
            # static inputs that are NOT outputs of an ancestor graph in this path
            for i in self.wrapped_function.static_input_idxs:
                if isinstance(
                    inputs[i], torch.Tensor
                ) and not self._is_cuda_graph_recorded_tensor(inputs[i]):
                    yield inputs[i]

        # see: output_is_alias_of_persistent_static_inputs above
        static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
            inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
            for inp in itertools.chain(
                static_input_iter(), self.wrapped_function.constants
            )
        }

        if config.triton.slow_path_cudagraph_asserts:
            # need to use parent live weakrefs because live_indices isnt set yet
            memory = (
                [] if self.parent is None else list(self.parent.path_live_weakrefs())
            )
            memory += [
                StorageWeakRefWrapper(elem)
                for i, elem in enumerate(inputs)
                if isinstance(elem, torch.Tensor)
                and i not in self.wrapped_function.static_input_idxs
                and elem.untyped_storage().data_ptr() != 0
            ]
            check_memory_pool(self.device, self.cuda_graphs_pool, memory)

        # capture: rng state preserved, cublas workspaces cleared, allocations
        # routed into the shared cudagraph pool on our recording stream
        with preserve_rng_state(), torch.cuda.device(
            self.device
        ), clear_cublas_manager(), torch.cuda.graph(
            self.graph,
            stream=self.stream,
            pool=self.cuda_graphs_pool,
            capture_error_mode="thread_local",
        ), get_history_recording():
            static_outputs = model(inputs)

        # running model should reclaim memory
        assert len(inputs) == 0

        if not isinstance(static_outputs, (list, tuple)):
            static_outputs = (static_outputs,)

        self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs)

        return static_outputs
1113
+
1114
+ def _add_first_outputs(
1115
+ self,
1116
+ outputs,
1117
+ static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
1118
+ ):
1119
+ "Add the outputs from the first invocation of the node and set up metadata"
1120
+
1121
+ # getting liveness before we have added the outputs to path, so the length
1122
+ # of the two lists is equal
1123
+ prev_liveness = self.recorded_liveness_before_graph
1124
+ curr_liveness = self._get_liveness(self.path_weakrefs)
1125
+
1126
+ delta = self._get_different_indices(prev_liveness, curr_liveness)
1127
+ self.expected_dead_indices_after_graph = delta
1128
+
1129
+ assert len(self.outputs_weakrefs) == 0
1130
+ # index from data pointer to index in outputs
1131
+ output_new_storages_index: Dict[StorageDataPtr, int] = {}
1132
+
1133
+ self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
1134
+ self.static_output_tensors = [None for _ in range(len(outputs))]
1135
+
1136
+ for i, o in enumerate(outputs):
1137
+ if o is None or not isinstance(o, torch.Tensor):
1138
+ self.output_storage_alias.append(UnaliasedStorage)
1139
+ continue
1140
+
1141
+ torch._check(
1142
+ o.is_cuda or o.untyped_storage().data_ptr() == 0,
1143
+ lambda: (
1144
+ "Expected all cuda outputs in cuda graph recording. Non cuda output "
1145
+ f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
1146
+ ),
1147
+ ),
1148
+
1149
+ ref = static_input_persistent_storage_ptrs.get(
1150
+ o.untyped_storage().data_ptr(), None
1151
+ )
1152
+ # also treat empty storages as static outputs because we do not need to manage their lifetime
1153
+ # and they should not participate in checkpointing
1154
+ is_empty_storage = o.untyped_storage().data_ptr() == 0
1155
+ if (ref and ref() is not None) or is_empty_storage:
1156
+ self.output_storage_alias.append(None)
1157
+ self.static_output_tensors[i] = o
1158
+ continue
1159
+
1160
+ path_ref = self._is_alias_of_live_recorded_tensor(o)
1161
+ if path_ref is not None:
1162
+ self._mark_prior_graph_output_as_aliased(path_ref)
1163
+ self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
1164
+ continue
1165
+
1166
+ if o.untyped_storage().data_ptr() in output_new_storages_index:
1167
+ index = output_new_storages_index[o.untyped_storage().data_ptr()]
1168
+ self.unaliased_in_all_paths[index] = False
1169
+ self.output_storage_alias.append(AliasesNewOutput(index))
1170
+ continue
1171
+
1172
+ output_new_storages_index[o.untyped_storage().data_ptr()] = i
1173
+ self.output_storage_alias.append(UnaliasedStorage)
1174
+ self.unaliased_in_all_paths[i] = True
1175
+
1176
+ if self.stack_traces is None:
1177
+ self.stack_traces = [None for _ in range(len(outputs))]
1178
+ else:
1179
+ assert len(self.stack_traces) == len(
1180
+ outputs
1181
+ ), "Wrong number of stack traces passed in"
1182
+
1183
+ assert not self.outputs_weakrefs
1184
+ for out, static_output_tensor in zip(outputs, self.static_output_tensors):
1185
+ if not isinstance(out, torch.Tensor) or static_output_tensor is not None:
1186
+ self.outputs_weakrefs.append(None)
1187
+ self.tensor_weakrefs.append(None)
1188
+ else:
1189
+ self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
1190
+ self.tensor_weakrefs.append(TensorWeakRef(out))
1191
+
1192
+ self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
1193
+ self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
1194
+ self.device, self.cuda_graphs_pool
1195
+ )
1196
+
1197
+ # now, get liveness with outputs added
1198
+ for depth in range(len(self.path_weakrefs)):
1199
+ for output_index in range(len(self.path_weakrefs[depth])):
1200
+ if is_live(self.path_weakrefs[depth][output_index]):
1201
+ self.live_indices_after_graph.append((depth, output_index))
1202
+
1203
+ self.debug_check_invariants_after_invocation()
1204
+ if config.triton.slow_path_cudagraph_asserts:
1205
+ check_memory_pool(
1206
+ self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs())
1207
+ )
1208
+
1209
+ def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex):
1210
+ "Remove a graph output from the unaliased, cached tensors in an ancestor node"
1211
+ depth, output_index = index
1212
+ node = list(self._path_from_root)[depth]
1213
+ node.unaliased_in_all_paths[output_index] = False
1214
+ x = self.path_weakrefs[depth][output_index]
1215
+ assert x is not None
1216
+ x.remove_extra_reference()
1217
+
1218
    def _initialize_cached_tensors(self):
        """Construct and cache output tensors for every output that is unaliased
        in all paths, registering them with autograd's cached-tensor machinery
        and a refcount-based liveness check."""
        # we should not be clearing output_weakrefs, and they should be set in the first
        # record run
        assert len(self.outputs_weakrefs) == len(self.outputs_metadata)

        for i, (storage_info, metadata, make_cached) in enumerate(
            zip(
                self.output_storage_alias,
                self.outputs_metadata,
                self.unaliased_in_all_paths,
            )
        ):
            if not make_cached:
                self.cached_tensor_outputs.append(None)
                continue

            assert storage_info is UnaliasedStorage
            assert isinstance(metadata, dict)
            s = self.create_storage(metadata)
            out = self._reconstruct_from_tensor_metadata(metadata, storage=s)

            # XXX: let autograd know that there will be an additional reference to the tensor
            # that can be ignored when deciding whether to do gradient buffer inplacing.
            # Otherwise, inplacing could differ between tracing and subsequent execution.
            # For some models we tested this led to inputs no longer being in cudagraph pools,
            # leading to spurious re-recordings.
            # It also tells AMP cache that even though the tensor impls cannot be cached
            # in dtype conversions.

            torch._C._add_cached_tensor(out)

            # weakref to self avoids a reference cycle node -> closure -> node
            self_ref = weakref.ref(self)

            # one reference in our array, and calling sys.getrefcount bumps the refcount by one
            def check_refcount(i):
                self_loc = self_ref()
                if self_loc is None:
                    return False
                return self_loc.get_output_refcount(i) == 2

            # bind the current i; a bare closure would late-bind the loop variable
            check = functools.partial(check_refcount, i=i)

            self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check)
            self.cached_tensor_outputs.append(out)
1263
    def get_output_refcount(self, index):
        # Refcount of the cached output tensor at `index`. The element is passed
        # directly to sys.getrefcount (no local binding), so the expected value
        # for an otherwise-unreferenced cached tensor is 2: one for our list slot,
        # one for the getrefcount argument. check_refcount relies on this.
        return sys.getrefcount(self.cached_tensor_outputs[index])
1265
+
1266
+ @property
1267
+ def parent(self):
1268
+ "unwraps the weakref to _parent"
1269
+ return self._parent() if self._parent is not None else None
1270
+
1271
+ @property
1272
+ def _path_to_root(self):
1273
+ "Returns all nodes in the path starting at self and ending at root"
1274
+ node = self
1275
+ while node:
1276
+ yield node
1277
+ node = node.parent
1278
+
1279
+ @property
1280
+ def _path_from_root(self):
1281
+ "Returns all nodes in the path starting at the root and ending at self"
1282
+ nodes = reversed(list(self._path_to_root))
1283
+ yield from nodes
1284
+
1285
+ def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor):
1286
+ "Is this tensor an output of a node in this path"
1287
+ for output_refs in self.path_weakrefs:
1288
+ for storage_weak_ref in output_refs:
1289
+ if storage_weak_ref is None:
1290
+ continue
1291
+ # don't need to check liveness of storage since the cuda graph managed
1292
+ # memory is never released.
1293
+ data_ptr = storage_weak_ref.data_ptr()
1294
+ if t.untyped_storage().data_ptr() == data_ptr:
1295
+ return True
1296
+
1297
+ return False
1298
+
1299
+ def _is_alias_of_live_recorded_tensor(
1300
+ self, t: torch.Tensor
1301
+ ) -> Optional[PathOutputIndex]:
1302
+ for depth, output_refs in enumerate(self.path_weakrefs):
1303
+ for output_index, storage_ref in enumerate(output_refs):
1304
+ if (storage_and_ptr := maybe_deref(storage_ref)) is not None:
1305
+ storage, ptr = storage_and_ptr
1306
+ if ptr == t.untyped_storage().data_ptr():
1307
+ return (depth, output_index)
1308
+
1309
+ return None
1310
+
1311
+ @staticmethod
1312
+ def _check_liveness(
1313
+ indices: List[PathOutputIndex],
1314
+ output_refs: List[List[Optional[StorageWeakRefWrapper]]],
1315
+ ):
1316
+ "Check that all of the indices specified are dead references"
1317
+ for depth, output_index in indices:
1318
+ w = output_refs[depth][output_index]
1319
+ assert w is not None
1320
+ if w() is not None:
1321
+ return False
1322
+ return True
1323
+
1324
    def add_child(self, function_id: FunctionID, node: CUDAGraphNode):
        "Adds node as a child of self"
        self.children[function_id].append(node)
1328
+ @staticmethod
1329
+ def _get_different_indices(
1330
+ prev: List[List[bool]], curr: List[List[bool]]
1331
+ ) -> List[PathOutputIndex]:
1332
+ "Find indices where the two lists differ."
1333
+ dead_indices = []
1334
+ assert len(prev) <= len(curr)
1335
+ for i, (outputs1, outputs2) in enumerate(zip(prev, curr)):
1336
+ assert len(outputs1) == len(outputs2)
1337
+ for j, (output1, output2) in enumerate(zip(outputs1, outputs2)):
1338
+ if output1 != output2:
1339
+ dead_indices.append((i, j))
1340
+
1341
+ return dead_indices
1342
+
1343
+ @staticmethod
1344
+ def _get_liveness(
1345
+ weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
1346
+ ) -> List[List[bool]]:
1347
+ "Maps weakrefs to true if the reference is alive and false otherwise"
1348
+ if len(weakrefs) == 0:
1349
+ return []
1350
+
1351
+ return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
1352
+
1353
    def debug_assert_invariants(
        self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
    ):
        """Debug-only invariant checks on this path's recorded liveness:
        path_weakrefs mirror each node's outputs_weakrefs, live outputs match
        `expected_liveness` (tensors may die early but never resurrect),
        persistent-alias outputs live outside the pool, and `newly_dead`
        indices are actually dead. No-op unless fast_path asserts are enabled.
        """
        if not config.triton.fast_path_cudagraph_asserts:
            return

        for i, node in enumerate(self._path_from_root):
            assert self.path_weakrefs[i] is node.outputs_weakrefs

        nodes = list(self._path_from_root)

        live_blocks = get_block_addrs(self.cuda_graphs_pool)

        live_storage_data_ptrs = set()
        live_storage_weak_ptrs = set()

        for depth, outputs_liveness in enumerate(expected_liveness):
            for output_idx, output_liveness in enumerate(outputs_liveness):
                # tensor can die early, but it can't be alive when it should be dead
                w = self.path_weakrefs[depth][output_idx]
                if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None:
                    assert output_liveness
                    stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr
                    # a data_ptr and its weak ptr must be tracked together
                    assert (stor_data_ptr in live_storage_data_ptrs) == (
                        stor_weak_ptr in live_storage_weak_ptrs
                    )
                    live_storage_data_ptrs.add(stor_data_ptr)
                    live_storage_weak_ptrs.add(stor_weak_ptr)

                    is_persistent_alias = (
                        nodes[depth].static_output_tensors[output_idx] is not None
                    )

                    if is_persistent_alias:
                        # aliases of persistent static inputs must not live in the pool
                        assert stor_data_ptr not in live_blocks

        for depth, output_index in newly_dead:
            assert not is_live(self.path_weakrefs[depth][output_index])
1392
    def debug_check_invariants_before_invocation(self):
        # Before replay: liveness should match the state recorded before
        # capture, with the before-graph indices already dead.
        self.debug_assert_invariants(
            self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph
        )
1397
    def debug_check_invariants_after_invocation(self):
        # After replay: same baseline liveness, but the indices observed to die
        # during recording must now be dead as well.
        self.debug_assert_invariants(
            self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
        )
1402
+ def data_ptrs_dead_since_invocation(self) -> List[int]:
1403
+ """
1404
+ Since this node was invoked, return data ptrs of all tensor outputs that have died
1405
+ in the current executing tree path.
1406
+ """
1407
+ curr_liveness = self._get_liveness(self.path_weakrefs)
1408
+ _get_different_indices = self._get_different_indices(
1409
+ self.recorded_liveness_after_graph, curr_liveness
1410
+ )
1411
+
1412
+ path = list(self._path_from_root)
1413
+ ptrs_to_deallocate = []
1414
+ for depth, output_index in _get_different_indices:
1415
+ ptrs_to_deallocate.append(
1416
+ path[depth].outputs_metadata[output_index]["data_ptr"]
1417
+ )
1418
+
1419
+ return ptrs_to_deallocate
1420
+
1421
+ def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
1422
+ for i, j in self.live_indices_after_graph:
1423
+ out = self.path_weakrefs[i][j]
1424
+ if out is not None and is_live(out):
1425
+ yield out
1426
+
1427
+ def remove_node_cached_tensors(self):
1428
+ for t in self.cached_tensor_outputs:
1429
+ if t is not None:
1430
+ torch._C._remove_cached_tensor(t)
1431
+ self.cached_tensor_outputs.clear()
1432
+
1433
+ for i, unaliased in enumerate(self.unaliased_in_all_paths):
1434
+ if unaliased:
1435
+ n = self.outputs_weakrefs[i]
1436
+ assert n is not None
1437
+ n.remove_extra_reference()
1438
+
1439
+ def remove_path_cached_tensors(self):
1440
+ for node in self._path_from_root:
1441
+ node.remove_node_cached_tensors()
1442
+
1443
    def clear_path_state(self):
        "Clear the path state in this current executing node"
        # this doesnt actually do anything right now, leaving it as placeholder
        pass
1448
+ @staticmethod
1449
+ def _tensor_metadata(x, ignore_storage_offset=True):
1450
+ assert isinstance(x, torch.Tensor)
1451
+ # We ignore the storage offset for inputs, but not for outputs
1452
+ # TODO: - should we make the storage resizable ?
1453
+ return {
1454
+ "nbytes": x.untyped_storage().nbytes(),
1455
+ "data_ptr": x.untyped_storage().data_ptr(),
1456
+ "size": x.shape,
1457
+ "stride": x.stride(),
1458
+ "dtype": x.dtype,
1459
+ "device": x.device,
1460
+ "storage_offset": x.storage_offset() if not ignore_storage_offset else 0,
1461
+ }
1462
+
1463
+ def _reconstruct_from_tensor_metadata(
1464
+ self, metadata: Dict[str, Any], storage=None
1465
+ ) -> Tensor:
1466
+ s = self.create_storage(metadata) if storage is None else storage
1467
+ return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s)
1468
+
1469
+ def create_storage(self, metadata):
1470
+ return torch._C._construct_storage_from_data_pointer(
1471
+ metadata["data_ptr"], metadata["device"], metadata["nbytes"]
1472
+ )
1473
+
1474
    def _allocate_and_copy_recording_inputs(
        self, inputs
    ) -> List[Union[torch.Tensor, int]]:
        """
        Allocate inputs for non static, non cudagraph-managed tensors in the memory pool
        and copy over the tensor values.
        """

        torch.cuda.synchronize()
        self.stream.wait_stream(torch.cuda.current_stream())
        recording_inputs: List[Union[Tensor, int]] = []

        # allocations below go into the shared cudagraph pool on our stream
        with warnings.catch_warnings(record=True), torch.cuda.device(
            self.device
        ), _use_cuda_memory_pool_manager(
            self.device,
            mem_pool=self.cuda_graphs_pool,
            stream=self.stream,
        ):
            for i, inp in enumerate(inputs):
                if not isinstance(inp, torch.Tensor):
                    assert isinstance(inp, int)
                    recording_inputs.append(inp)
                elif i not in self.static_input_idxs:
                    # static_input does an allocation!
                    recording_inputs.append(static_input(inp))
                    # copy over and clear non recording input
                    self._copy_input(i, recording_inputs[-1], inp)
                    # drop references so the caller's copy can be reclaimed
                    inputs[i] = None
                    del inp
                else:
                    recording_inputs.append(inp)

        return recording_inputs
1509
    def check_invariants(self, inputs: List[Tensor]) -> bool:
        """
        Checks if this node can be run. The same pattern of tensor liveness and tensors
        managed in the cudagraph private pool must remain stable.

        Returns False when replay is not safe; raises (via torch._check) when an
        input that died during recording is still alive now. Clears the
        cudagraph-managed entries of `inputs` as a side effect.
        """

        # previously managed data pointers remain stable
        for idx in self.cudagraph_managed_idxs:
            if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]:
                return False

        if not self._check_liveness(
            self.expected_dead_indices_before_graph, self.path_weakrefs
        ):
            return False

        # the cudagraph managed tensors which died upon recording must also die upon
        # this invocation. it is too late to check after we've replayed the graph,
        # because we would have already written over their memory.
        for idx in self.cudagraph_managed_idxs:
            inputs[idx] = None  # type: ignore[call-overload]

        torch._check(
            self._check_liveness(
                self.expected_dead_indices_after_graph, self.path_weakrefs
            ),
            lambda: "TODO: graph recording observed an input tensor deallocate during graph "
            " recording that did not occur during replay. Please file an issue.",
        )
        return True
1540
+ def num_descendants(self) -> int:
1541
+ "Total number of descendents of this node"
1542
+ num_desc = 0
1543
+ for children in self.children.values():
1544
+ for child in children:
1545
+ num_desc += 1
1546
+ num_desc += child.num_descendants()
1547
+ return num_desc
1548
+
1549
+
1550
def get_cudagraph_segments(pool_id):
    """Return the allocator segments that belong to the given cudagraph pool."""
    return [
        seg
        for seg in torch.cuda.memory_snapshot()
        if seg["segment_pool_id"] == pool_id
    ]
1554
+
1555
def get_block_addrs(pool_id, live_only=True):
    """Return start addresses of blocks in the pool's segments; by default only
    blocks whose state is "active_allocated"."""
    addrs = []
    for segment in get_cudagraph_segments(pool_id):
        cursor = segment["address"]
        for block in segment["blocks"]:
            if not live_only or block["state"] == "active_allocated":
                addrs.append(cursor)
            # blocks are laid out contiguously within a segment
            cursor += block["size"]
    return addrs
1568
+
1569
def format_tb(frames):
    """Render allocator-snapshot frame dicts (filename/line/name keys) as a
    Python-traceback-style string."""
    summaries = [
        traceback.FrameSummary(frame["filename"], frame["line"], frame["name"])
        for frame in frames
    ]
    return "".join(traceback.format_list(summaries))
1579
+
1580
def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]):
    """Verify that the live storages tracked by cudagraph trees exactly match
    the active allocations in the given memory pool.

    Raises (via torch._check / RuntimeError) on divergence in either direction:
    tracked storages missing from the pool, or pool allocations not tracked.

    Fix: the final guard previously compared the dict itself to 0
    (`allocated_not_in_live_storages != 0`), which is always True — the error
    branch ran even when nothing was unaccounted for. Use truthiness instead.
    """
    assert all(
        isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
    )  # noqa: C419
    unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}

    # check if there is a divergence first, then do the expensive snapshot call after
    # we know it will error
    if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages):
        return

    # at this point we are past the fast-path. we have seen rare cases where a dead tensor is dead,
    # but hasn't been gc'd yet, and gives false positive for allocated_not_in_live_storages
    gc.collect()

    segments = get_cudagraph_segments(pool_id)

    allocated_not_in_live_storages = {}

    for segment in segments:
        addr = segment["address"]
        for block in segment["blocks"]:
            if block["state"] == "active_allocated":
                if addr not in unique_storages:
                    allocated_not_in_live_storages[addr] = block
                else:
                    unique_storages.remove(addr)

            addr += block["size"]

    torch._check(
        len(unique_storages) == 0,
        lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
    )

    if allocated_not_in_live_storages:
        formatted = []
        for dp, block in allocated_not_in_live_storages.items():
            trace = format_tb(block.get("frames", []))
            formatted.append(f"Data Pointer: {dp}, history: \n{trace}")
        formatted_s = "\n".join(formatted)
        msg = (
            f"These live storage data ptrs are in the cudagraph pool but not "
            f"accounted for as an output of cudagraph trees: \n\n{formatted_s}"
        )
        raise RuntimeError(msg)
1627
+
1628
class ExecutionState(Enum):
    """
    Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated
    in the cuda graph pool. Otherwise will reflect the state of the most recently executed node.
    """

    # no live allocations in the cuda graph pool
    NONE = auto()
    # most recent node ran eagerly in the pool to warm up
    WARMUP = auto()
    # most recent node was captured into a cuda graph
    RECORDING = auto()
    # most recent node replayed a previously captured graph
    EXECUTION = auto()
1639
+
1640
class CompilationMode(Enum):
    """Which phase a torch.compile'd function belongs to; used for the
    pending-backward generation bookkeeping in CUDAGraphTreeManager."""

    FORWARD = auto()
    BACKWARD = auto()
    INFERENCE = auto()
1645
+
1646
+ class CUDAGraphTreeManager:
1647
+ """
1648
+ Groups individual recordings or executions of cuda graphs into a tree of recordings,
1649
+ and checks required invariants, and manages warmups of graphs.
1650
+
1651
+ When graphs are recorded in the same tree, it enforces subsequent execution
1652
+ to follow the same order and have the same output tensor livespans. To remove
1653
+ unnecessary coupling of cuda graphs (and additional imposed invariants),
1654
+ the tree manager will end a currently recording tree whenever it is valid - when
1655
+ the memory pool no longer has any live allocations.
1656
+
1657
+ We ignore outputs from a previous generation that correspond to prior model outputs.
1658
+ Currently this is hardcoded `GenerationTracker.generation` tracked in torch dynamo.
1659
+ # TODO: make generation increment configurable, warn on overwrite.
1660
+
1661
+ We run graph warmups in the cudagraph memory pool and return the result on the first invocation
1662
+ of a function. For many models it is important to reclaim activations as you run the backward.
1663
+ If we were to warm up the model and keep an extra copy of the inputs around to subsequently
1664
+ use for recording, we would incur a memory penalty. Additionally, if we are part way through training
1665
+ your model and need to recompile, memory will be allocated to the cuda graph pool, so we run this
1666
+ warmup run in the cuda graph memory pool. As for recording, warm up needs the state of live tensors
1667
+ to be accurately reflected so we checkpoint the allocator state if we need to warm up following graph
1668
+ replay.
1669
+ """
1670
+
1671
    def __init__(self, device_index: int):
        """Set up the shared memory pool, recording stream, and bookkeeping for
        all cuda graph recordings/executions on `device_index`."""
        # roots are functions which have no dependencies on another node. I.e.,
        # when they are first invoked, none of their inputs are outputs
        # of another node, nor are there any live outputs of another node whose
        # liveness would create a dependency.
        self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)

        # mapping from function id to wrapped function
        self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}

        self.ids_to_stack_traces: Dict[FunctionID, StackTraces] = {}

        self.warmed_up_functions: Set[FunctionID] = set()
        # if we fail to increment generation, and are stuck warming up,
        # only warn on each function once
        self.warned_functions: Set[FunctionID] = set()
        torch._C._set_cached_tensors_enabled(True)

        # NB: cuda caching allocator will remember the stream a segment is allocated to
        # and only allocate that segment to the same stream. we need to use a single stream
        # for all allocations to the memory pool, otherwise the allocations to separate streams
        # will not be reused; separate recordings would have used the same memory pool, but not
        # the same memory.

        with torch.cuda.device(device_index):
            torch.cuda.synchronize()
            self.stream = torch.cuda.Stream()
            self.stream.wait_stream(torch.cuda.current_stream())

            # Keeps Memory Pool Alive
            self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
            self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()

            # empty capture: ties the pool's lifetime to self.graph
            with warnings.catch_warnings(record=True), torch.cuda.graph(
                self.graph,
                pool=self.cuda_graphs_thread_pool,
                stream=self.stream,
                capture_error_mode="thread_local",
            ):
                pass

        self.graph_counter = itertools.count(0)
        self.func_counter = itertools.count(0)

        # whether the current node is in a state of warmup, recording, execution. If
        # there is no current node the state will be ExecutionState.NONE.
        self.path_state = ExecutionState.NONE
        self.device_index = device_index

        # the most recently invoked cudagraph wrapping of a function. Will be None
        # when there is no output from a previous recording or execution whose memory
        # we need to respect in the cuda caching allocation. If you incremented generation,
        # this will also be None, as we ignore those allocations.
        self.current_node: Optional[CUDAGraphNode] = None

        # current generation of cudagraph invocations. when torch.compile is run
        # we increment the current generation. are willing to ignore live outputs
        # of a previous generation in checking liveness.
        self.current_gen: int = -1

        # number of instances we are in execution and failed to match to an
        # existing child
        self.debug_fail_counter = 0
        # number of instances we had to checkpoint the function
        self.debug_checkpointing_counter = 0

        self.id_to_mode: Dict[FunctionID, CompilationMode] = {}

        # Note: [Backward Generation Handling]
        # We generally perform a sequence of forward executions followed by backward executions.
        # If multiple torch.compile wrapped forwards are executed with their backwards pending,
        # we should not disregard the outputs from a prior torch.compile since the entire training
        # loop hasn't completed.  Occasionally, a backward pass corresponding to a forward pass may
        # not be executed, so we cannot wait for all pending forward pass backward completions, so
        # we cannot wait for all backwards to have been invoked. Instead we wait for a single backward
        # invocation. Triggering a backward pass typically doesn't lead to another torch.compile
        # invocation, making it less likely for the generation to increase between multiple
        # backward calls. The following use case is covered by this approach:
        # mod1 = torch.compile(...)
        # mod2 = torch.compile(...)
        # mod2(mod1(x)).sum().backward()

        self.running_forwards_with_pending_backwards = False
1755
+ def run(self, new_inputs: List[Tensor], function_id: FunctionID):
1756
+ assert self.graph is not None, "Running CUDAGraph after shutdown"
1757
+ out = self._run(new_inputs, function_id)
1758
+
1759
+ # The forwards are only pending following invocation, not before
1760
+ mode = self.id_to_mode[function_id]
1761
+ if mode == CompilationMode.FORWARD:
1762
+ self.running_forwards_with_pending_backwards = True
1763
+ elif mode == CompilationMode.BACKWARD:
1764
+ self.running_forwards_with_pending_backwards = False
1765
+
1766
+ return out
1767
+
1768
+ def set_to_running_backward(self):
1769
+ self.running_forwards_with_pending_backwards = False
1770
+
1771
+ def _run(self, new_inputs: List[Tensor], function_id: FunctionID):
1772
+ # we will try to end the current execution lazily, since
1773
+ # we dont want to do unnecessary checking of the existing outputs
1774
+ # on the hot path, but both recording and warmup only happen once
1775
+ # so we check up front
1776
+ if self.in_recording:
1777
+ self.try_end_curr_recording(function_id)
1778
+
1779
+ if self.in_warmup:
1780
+ self.try_end_curr_warmup(function_id)
1781
+
1782
+ # warming up a function and subsequentally recording may use different memory addresses
1783
+ # because both depend on the state of the caching allocator. if we warm up graph A,
1784
+ # then warm up graph B and make more allocations, the subsequent recording of A will not
1785
+ # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only
1786
+ # be followed by warm up runs.
1787
+ if (
1788
+ not (
1789
+ function_id in self.warmed_up_functions
1790
+ or config.triton.skip_cudagraph_warmup
1791
+ )
1792
+ ) or self.in_warmup:
1793
+ # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state.
1794
+ # Both Recording and Warmup will be reflected in the allocator and dont need changes
1795
+ if self.path_state == ExecutionState.EXECUTION:
1796
+ self.apply_checkpoint_execution_state_in_allocator()
1797
+
1798
+ return self.run_eager(new_inputs, function_id)
1799
+
1800
+ child_nodes = (
1801
+ self.roots if self.current_node is None else self.current_node.children
1802
+ )
1803
+
1804
+ if not self.in_recording:
1805
+ for child in child_nodes[function_id]:
1806
+ # here we are checking memory consistency between recording and execution,
1807
+ # as well as things like stability of tensor locations, etc
1808
+ # and other
1809
+ if child.check_invariants(new_inputs):
1810
+ return self.execute_node(child, new_inputs)
1811
+
1812
+ # now that we know the new function can't be run as a child of the
1813
+ # current node, if it is a root, try to end the current execution.
1814
+ # as noted above, we want to do this lazily to avoid having to
1815
+ # check all existing outputs
1816
+ if self.current_node is not None and function_id in self.roots:
1817
+ self.try_end_curr_execution()
1818
+
1819
+ # run again to hit the root matching case which must succeed
1820
+ if self.current_node is None:
1821
+ return self.run(new_inputs, function_id)
1822
+
1823
+ # at this point, we necessarily will do a new recording
1824
+ self.debug_fail_counter += 1
1825
+
1826
+ self.try_end_curr_execution()
1827
+ if self.current_node is not None:
1828
+ self.apply_checkpoint_execution_state_in_allocator()
1829
+
1830
+ # now, we are in a recording state !
1831
+ return self.record_function(new_inputs, function_id)
1832
+
1833
+ def shutdown(self):
1834
+ """
1835
+ Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
1836
+ might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown
1837
+ to avoid a reference cycle.
1838
+ """
1839
+ nodes = []
1840
+ for roots in self.roots.values():
1841
+ nodes.extend(roots)
1842
+
1843
+ while nodes:
1844
+ node = nodes.pop()
1845
+ for children in node.children.values():
1846
+ nodes.extend(children)
1847
+ node.remove_node_cached_tensors()
1848
+ node.graph = None
1849
+
1850
+ self.graph = None
1851
+ self.roots = None # type: ignore[assignment]
1852
+ self.current_node = None
1853
+
1854
+ def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]:
1855
+ graph_id = self.new_graph_id()
1856
+ log.debug(
1857
+ "Recording function %d of graph recording id %d",
1858
+ function_id.id,
1859
+ graph_id.id,
1860
+ )
1861
+ torch.cuda.synchronize()
1862
+ node = CUDAGraphNode(
1863
+ self.ids_to_funcs[function_id],
1864
+ graph_id,
1865
+ self.current_node,
1866
+ new_inputs,
1867
+ self.cuda_graphs_thread_pool,
1868
+ self.device_index,
1869
+ self.ids_to_stack_traces[function_id],
1870
+ self.stream,
1871
+ )
1872
+ if self.current_node is None:
1873
+ self.roots[function_id].append(node)
1874
+ else:
1875
+ self.current_node.add_child(function_id, node)
1876
+ self.current_node = node
1877
+ self.path_state = ExecutionState.RECORDING
1878
+ self.update_generation()
1879
+ torch.cuda.synchronize()
1880
+ return node.run_first_inputs(new_inputs)
1881
+
1882
+ def execute_node(self, node: CUDAGraphNode, new_inputs) -> List[Optional[Tensor]]:
1883
+ self.current_node = node
1884
+ self.path_state = ExecutionState.EXECUTION
1885
+ self.update_generation()
1886
+ return node.run(new_inputs)
1887
+
1888
+ def run_eager(self, new_inputs, function_id: FunctionID):
1889
+ # this is only stored on current node, because when we start a new path,
1890
+ # we will deallocate it
1891
+ already_warm = function_id in self.warmed_up_functions
1892
+ if not already_warm:
1893
+ log.debug("Running warmup of function %d", function_id.id)
1894
+ else:
1895
+ log.debug(
1896
+ "Running eager of function %d because ancestor needed to warm up",
1897
+ function_id.id,
1898
+ )
1899
+ self.warmed_up_functions.add(function_id)
1900
+ node = CUDAWarmupNode(
1901
+ self.ids_to_funcs[function_id],
1902
+ self.current_node,
1903
+ self.cuda_graphs_thread_pool,
1904
+ self.graph,
1905
+ self.device_index,
1906
+ self.ids_to_stack_traces[function_id],
1907
+ self.stream,
1908
+ already_warm,
1909
+ )
1910
+ self.current_node = node
1911
+ self.path_state = ExecutionState.WARMUP
1912
+ self.update_generation()
1913
+ return node.run(new_inputs)
1914
+
1915
+ def new_graph_id(self) -> GraphID:
1916
+ return GraphID(next(self.graph_counter))
1917
+
1918
+ def new_func_id(self) -> FunctionID:
1919
+ return FunctionID(next(self.func_counter))
1920
+
1921
+ def add_function(
1922
+ self,
1923
+ model,
1924
+ inputs,
1925
+ static_input_idxs,
1926
+ stack_traces,
1927
+ mode,
1928
+ constants,
1929
+ ) -> Tuple[Callable[..., Any], List[Optional[Tensor]]]:
1930
+ id = self.new_func_id()
1931
+ self.ids_to_stack_traces[id] = stack_traces
1932
+ self.ids_to_funcs[id] = WrappedFunction(
1933
+ model,
1934
+ static_input_idxs,
1935
+ id,
1936
+ tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
1937
+ )
1938
+ self.id_to_mode[id] = mode
1939
+ fn = functools.partial(self.run, function_id=id)
1940
+
1941
+ # container needs to set clean up when fn dies
1942
+ get_container(self.device_index).add_strong_reference(fn)
1943
+ return fn, fn(inputs)
1944
+
1945
+ @property
1946
+ def in_recording(self):
1947
+ return self.path_state == ExecutionState.RECORDING
1948
+
1949
+ @property
1950
+ def in_warmup(self):
1951
+ return self.path_state == ExecutionState.WARMUP
1952
+
1953
+ def get_roots(self) -> Iterator[CUDAGraphNode]:
1954
+ for nodes in self.roots.values():
1955
+ yield from nodes
1956
+
1957
+ @property
1958
+ def current_node(self):
1959
+ return self._current_node
1960
+
1961
+ @current_node.setter
1962
+ def current_node(self, value):
1963
+ self._current_node = value
1964
+ if value is None:
1965
+ self.path_state = ExecutionState.NONE
1966
+
1967
+ def update_generation(self):
1968
+ self.current_gen = self.get_curr_generation()
1969
+
1970
+ @staticmethod
1971
+ def get_curr_generation() -> int:
1972
+ if MarkStepBox.mark_step_counter != 0:
1973
+ return MarkStepBox.mark_step_counter
1974
+
1975
+ return GenerationTracker.generation
1976
+
1977
+ @staticmethod
1978
+ def user_invoked_mark_step():
1979
+ return MarkStepBox.mark_step_counter != 0
1980
+
1981
+ def can_start_new_generation(self) -> bool:
1982
+ if not self.in_new_torch_compile_invocation():
1983
+ return False
1984
+
1985
+ if self.user_invoked_mark_step():
1986
+ return True
1987
+
1988
+ return not self.running_forwards_with_pending_backwards
1989
+
1990
+ def in_new_torch_compile_invocation(self):
1991
+ return self.current_gen != self.get_curr_generation()
1992
+
1993
+ def try_end_curr_recording(self, function_id: FunctionID) -> None:
1994
+ """
1995
+ Check if the current recording can be terminated, either because all outputs of the
1996
+ previously recorded node are dead or because it was executed in a different
1997
+ generation. Will set current_node to None and in_recording to False if successful.
1998
+ """
1999
+ assert self.in_recording
2000
+ assert self.current_node is not None
2001
+
2002
+ # multiple invocations, allow overwriting the previous generation
2003
+ if self.can_start_new_generation():
2004
+ self.dealloc_current_path_weakrefs()
2005
+ self.clear_current_path_state_and_set_to_none()
2006
+ return
2007
+
2008
+ if self.current_node.all_outputs_are_dead():
2009
+ self.clear_current_path_state_and_set_to_none()
2010
+ return
2011
+
2012
+ self.check_warn_on_unable_to_start_executing(function_id)
2013
+
2014
+ def try_end_curr_execution(self) -> None:
2015
+ """
2016
+ Check if the current executing node can be terminated, either because all outputs of the
2017
+ previously executed node are dead or because it was executed in a different generation.
2018
+ Will set current_node to None if successful.
2019
+ """
2020
+
2021
+ assert not self.in_recording
2022
+ if self.current_node is None:
2023
+ return
2024
+
2025
+ if self.can_start_new_generation():
2026
+ self.clear_current_path_state_and_set_to_none()
2027
+ return
2028
+
2029
+ if self.current_node.all_outputs_are_dead():
2030
+ self.clear_current_path_state_and_set_to_none()
2031
+
2032
+ def try_end_curr_warmup(self, function_id: FunctionID):
2033
+ if self.can_start_new_generation():
2034
+ self.dealloc_current_path_weakrefs()
2035
+ self.current_node = None
2036
+ return
2037
+
2038
+ if self.current_node.all_outputs_are_dead():
2039
+ self.current_node = None
2040
+ return
2041
+
2042
+ self.check_warn_on_unable_to_start_executing(function_id)
2043
+
2044
+ def check_warn_on_unable_to_start_executing(self, function_id: FunctionID):
2045
+ "Warn if we in a potential loop where we are unable to hit fast path"
2046
+ if (
2047
+ function_id in self.warned_functions
2048
+ or not self.in_new_torch_compile_invocation()
2049
+ ):
2050
+ return
2051
+
2052
+ existing_nodes = [
2053
+ node
2054
+ for node in self.current_node._path_from_root
2055
+ if node.wrapped_function.id == function_id
2056
+ ]
2057
+
2058
+ if len(existing_nodes) <= 1:
2059
+ return
2060
+
2061
+ # repeated same pattern
2062
+ parents = {
2063
+ n.parent.wrapped_function.id
2064
+ for n in itertools.chain(existing_nodes, (self.current_node,))
2065
+ if n.parent is not None
2066
+ }
2067
+ if len(parents) == len(existing_nodes):
2068
+ return
2069
+
2070
+ self.warned_functions.add(function_id)
2071
+ warnings.warn(
2072
+ "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. "
2073
+ "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() "
2074
+ "before each model invocation"
2075
+ )
2076
+
2077
+ def dealloc_current_path_weakrefs(self):
2078
+ # TODO: we could also allow the these weak refs to continue to be allocated,
2079
+ # but that adds some complications.
2080
+ for node in self.current_node._path_from_root:
2081
+ assert len(node.tensor_weakrefs) == len(node.stack_traces)
2082
+ for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces):
2083
+ ten = None if t is None else t()
2084
+ if ten is None:
2085
+ continue
2086
+
2087
+ stack_trace = (
2088
+ stack_trace.strip()
2089
+ if stack_trace
2090
+ else "[Could not find stack trace]"
2091
+ )
2092
+ msg = (
2093
+ "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
2094
+ f"Stack trace: {stack_trace}. "
2095
+ "To prevent overwriting, clone the tensor outside of torch.compile() "
2096
+ "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
2097
+ )
2098
+ torch._C._set_storage_access_error_msg(ten, msg)
2099
+
2100
+ deleted = set()
2101
+ for storage_ref in self.current_node.path_live_weakrefs():
2102
+ if storage_ref() and storage_ref.data_ptr() not in deleted:
2103
+ deleted.add(storage_ref.data_ptr())
2104
+ torch._C._free_And_Remove_DeleterFn(storage_ref())
2105
+
2106
+ def clear_current_path_state_and_set_to_none(self):
2107
+ self.current_node.clear_path_state()
2108
+ self.current_node = None
2109
+
2110
+ def apply_checkpoint_execution_state_in_allocator(self):
2111
+ """
2112
+ Checkpoint the current execution state in the caching allocator so that
2113
+ additional cudagraph recordings can be made respecting existent live storages.
2114
+ """
2115
+ self.debug_checkpointing_counter += 1
2116
+ log.debug(
2117
+ "Checkpointing cuda caching allocator state. Number of checkpoints %d",
2118
+ self.debug_checkpointing_counter,
2119
+ )
2120
+
2121
+ state = self.current_node.checkpointed_caching_state
2122
+ device = self.current_node.device
2123
+ assert state is not None and device is not None
2124
+
2125
+ # currently we deallocate on instead of allowing stale recordings
2126
+ stale_storages: List[int] = []
2127
+
2128
+ # remove cached tensors, otherwise they would prevent memory from being
2129
+ # reclaimed in subsequent recordings
2130
+ self.current_node.remove_path_cached_tensors()
2131
+ live_storages_wrappers = list(self.current_node.path_live_weakrefs())
2132
+
2133
+ live_storages_weak_refs = [t() for t in live_storages_wrappers]
2134
+ ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation()
2135
+ torch._C._cuda_setCheckpointPoolState(
2136
+ device, state, stale_storages, live_storages_weak_refs
2137
+ )
2138
+
2139
+ # NB: deduplicate aliased outputs
2140
+ for ptr in set(ptrs_to_deallocate):
2141
+ torch._C._cuda_cudaCachingAllocator_raw_delete(ptr)
2142
+
2143
+ # Now the live blocks should be exactly equal to the live storages in private pool
2144
+ if config.triton.slow_path_cudagraph_asserts:
2145
+ check_memory_pool(
2146
+ self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers
2147
+ )
2148
+ for wrapper in live_storages_wrappers:
2149
+ assert wrapper()
2150
+ assert torch._C._has_Standard_Deleter(wrapper())
2151
+ assert wrapper.data_ptr() not in ptrs_to_deallocate
2152
+
2153
+ def live_cudagraph_pool_storages_in_curr_execution(
2154
+ self,
2155
+ ) -> List[StorageWeakRefPointer]:
2156
+ if self.current_node is None:
2157
+ return []
2158
+ # explicitly ignoring previous recorded outputs from past path
2159
+ return [t() for t in self.current_node.path_live_weakrefs()]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from typing import Callable, List, TYPE_CHECKING
3
+
4
+ if TYPE_CHECKING:
5
+ import torch
6
+
7
+ # Executed in the order they're registered
8
+ INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
9
+
10
+
11
+ @contextlib.contextmanager
12
+ def intermediate_hook(fn):
13
+ INTERMEDIATE_HOOKS.append(fn)
14
+ try:
15
+ yield
16
+ finally:
17
+ INTERMEDIATE_HOOKS.pop()
18
+
19
+
20
+ def run_intermediate_hooks(name, val):
21
+ global INTERMEDIATE_HOOKS
22
+ hooks = INTERMEDIATE_HOOKS
23
+ INTERMEDIATE_HOOKS = []
24
+ try:
25
+ for hook in hooks:
26
+ hook(name, val)
27
+ finally:
28
+ INTERMEDIATE_HOOKS = hooks
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import Any, Callable, Generic, Literal, Optional, Tuple, TypeVar, Union
3
+ from unittest.mock import patch
4
+
5
+ import sympy
6
+ from typing_extensions import Protocol
7
+
8
+ import torch
9
+ import torch.utils._pytree as pytree
10
+ from torch.fx.graph import inplace_methods, magic_methods
11
+ from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str
12
+
13
+ T = TypeVar("T")
14
+ StoreMode = Optional[Literal["atomic_add"]]
15
+ ReductionType = Literal[
16
+ "argmax",
17
+ "argmin",
18
+ "welford_reduce",
19
+ "welford_combine",
20
+ "any",
21
+ "max",
22
+ "min",
23
+ "prod",
24
+ "sum",
25
+ "xor_sum",
26
+ ]
27
+
28
+
29
+ def _arg_str(a) -> str:
30
+ if isinstance(a, sympy.Expr):
31
+ return sympy_str(a)
32
+ return str(a)
33
+
34
+
35
+ # NB: This is not done as a parent class, because our ops handlers
36
+ # implementations make heavy use of __getattr__ magic, and pre-existing
37
+ # stubs for methods would interfere with this mechanism.
38
+ #
39
+ # TODO: A superclass that does desugaring for operations like
40
+ # reciprocal/square might be useful.
41
+ class OpsHandler(Protocol[T]):
42
+ """
43
+ Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
44
+ as well as the contract for op handlers. The type T signifies the domain
45
+ of the abstract analysis AKA what all of the functions return / take as arguments
46
+ anywhere compute occurs.
47
+
48
+ While these operators are typically dtype polymorphic (e.g., you can use mul
49
+ on both integers and floats), they do NOT do promotion and usually return the
50
+ same dtype as the input. You are expected to have handled type promotion
51
+ during ATen decompositions. Most operators correspond exactly to pointwise
52
+ operations as defined by torch, so when in doubt about semantics, check the
53
+ corresponding torch documentation. These are all scalar operations (so they
54
+ are defined to operate on a single element at a time.)
55
+
56
+ For convenience, many operators take a src_dtype which indicates what the dtype
57
+ of the input argument is. Although in principle this can be derived by an
58
+ analysis, providing this for ops where it is useful helps avoid having to repeatedly
59
+ recompute dtype in code generation.
60
+
61
+ Note that this often describes a class of static methods, for stateless
62
+ ops handlers.
63
+
64
+ Handlers are often defined using ``__getattr__`` metaprogramming, which means
65
+ that you cannot declare that a type implements a protocol by inheriting from
66
+ it (as the type stubs count as attribute declarations and impede the getattr
67
+ magic method from being called). Instead, define a function that casts an
68
+ argument of your type to the protocol, which is sufficient to induce mypy to
69
+ test that the protocol is implemented correctly. Search for ``_typecheck_``
70
+ in this file to see some examples. If you see an obscure error where a
71
+ class doesn't implement a Protocol, but mypy doesn't say why, check to see
72
+ that ``__getattr__`` is typed correctly (typically, it is not possible to
73
+ type ``__getattr__`` without typing it as ``Callable[..., Any]``)
74
+ """
75
+
76
+ def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
77
+ """Produces a scalar constant of type dtype."""
78
+ ...
79
+
80
+ def load_seed(self, name: str, offset: T):
81
+ """Computes inductor_prims.lookup_seed."""
82
+ ...
83
+
84
+ def rand(self, seed: T, offset: T) -> T:
85
+ """Computes inductor_prims.random with mode="rand". offset has dtype int32."""
86
+ ...
87
+
88
+ def randn(self, seed: T, offset: T) -> T:
89
+ """Computes inductor_prims.random with mode="randn". offset has dtype int32."""
90
+ ...
91
+
92
+ def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
93
+ """Computes inductor_prims.randint. offset has dtype int32."""
94
+ ...
95
+
96
+ def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
97
+ """
98
+ Computes body, but only perform loads/stores if the boolean mask
99
+ evaluates to true. For example, you would use this if you needed to
100
+ perform an indirect load that may not be valid on some elements;
101
+ without masking, invalid accesses can cause IMAs. When mask is true,
102
+ the result is the result of body; otherwise it is other.
103
+
104
+ Contrast this with ops.where, which can multiplex between two values
105
+ that have been unconditionally computed.
106
+ """
107
+ ...
108
+
109
+ def where(self, condition: T, input: T, other: T) -> T:
110
+ """
111
+ Computes torch.where: when condition is true, return input; otherwise return other.
112
+ """
113
+ ...
114
+
115
+ def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
116
+ """
117
+ Converts a sympy expression into a scalar of type dtype. expr is typically
118
+ an indexing expression, thus the name; however, it can also be used in
119
+ non-indexing situations.
120
+ """
121
+ ...
122
+
123
+ def to_dtype(
124
+ self, x: T, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
125
+ ) -> T:
126
+ """
127
+ Convert x to dtype. src_dtype can be optionally set to specify what the original
128
+ dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
129
+ """
130
+ ...
131
+
132
+ def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
133
+ """
134
+ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
135
+ src_dtype must be the original type of x.
136
+ """
137
+ ...
138
+
139
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
140
+ # These operations are only available in a "kernel" context. Check
141
+ # torch._inductor.codegen.common.CSEProxy for their typical implementation
142
+ # in op handler (routing to their respective implementations in the kernel
143
+ # handler)
144
+ #
145
+ # Importantly, inside a kernel, indexing and mask variables are available
146
+ # in scope, which are typically used by sympy.Expr indexing.
147
+
148
+ def indirect_indexing(
149
+ self, x: T, size: sympy.Expr, check: bool = True
150
+ ) -> sympy.Expr:
151
+ """
152
+ Convert an integral x into a sympy.Expr that can be subsequently used in
153
+ indexing computation. 'size' represents an upper bound on the what valid
154
+ indexes can be; when 'check' is True, we check that the x is in bounds.
155
+
156
+ NB: This is typically mandatory to implement for any analysis, because you
157
+ MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
158
+ """
159
+ ...
160
+
161
+ def load(self, name: str, index: sympy.Expr) -> T:
162
+ """
163
+ Load from the memory location 'name', offset by some indexing expression 'index'.
164
+ """
165
+ ...
166
+
167
+ def store(
168
+ self,
169
+ name: str,
170
+ index: sympy.Expr,
171
+ value: T,
172
+ mode: StoreMode = None,
173
+ ) -> None:
174
+ """
175
+ Store 'value' to the memory location 'name' offset by 'expr'. If
176
+ specified, 'mode' can require the store to be an atomic addition.
177
+ """
178
+ ...
179
+
180
+ # TODO: Better explain how the "collective" semantics of these ops;
181
+ # remember that the input value is a scalar, you can't reduce on it in the
182
+ # traditional sense!
183
+ def reduction(
184
+ self,
185
+ dtype: torch.dtype,
186
+ src_dtype: torch.dtype,
187
+ reduction_type: ReductionType,
188
+ value: T,
189
+ ) -> Union[T, Tuple[T, ...]]:
190
+ """
191
+ Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype',
192
+ using 'dtype' as the accumulation dtype for the reduction. The result
193
+ is an intermediate computation which should be stored to the final
194
+ location using 'ops.store_reduction'.
195
+
196
+ Valid reduction types are . For Welford reduction types, this
197
+ function returns multiple outputs; consult reduction_num_outputs to
198
+ determine the amount in metaprogramming applications.
199
+ """
200
+ ...
201
+
202
+ # TODO: in practice, this seems to actually return None, but not returning
203
+ # a T makes common __getattr__ idioms not type correctly. Figure out if
204
+ # this should be returning something.
205
+ def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
206
+ """
207
+ Store the fully accumulated result of 'reduction' to the memory
208
+ location 'name' offset by 'expr'.
209
+ """
210
+ ...
211
+
212
+ def scan(
213
+ self, dtype: torch.dtype, combine_fn: Callable[[T, T], T], value: T, init: int
214
+ ) -> T:
215
+ """
216
+ Perform an associative scan on 'value'.
217
+ """
218
+ # TODO: Improve the description with some pseudocode
219
+ ...
220
+
221
+ def bucketize(
222
+ self,
223
+ values: T,
224
+ offsets_name: str,
225
+ offsets_size: sympy.Expr,
226
+ indexing_dtype: torch.dtype,
227
+ right: bool,
228
+ ) -> T:
229
+ # See [Note: Inductor bucketize op]
230
+ ...
231
+
232
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
233
+ # The following ops have semantics that correspond exactly to the torch
234
+ # operation with the same corresponding name.
235
+
236
+ def abs(self, x0: T) -> T:
237
+ ...
238
+
239
+ def exp(self, x0: T) -> T:
240
+ ...
241
+
242
+ def exp2(self, x0: T) -> T:
243
+ ...
244
+
245
+ def expm1(self, x0: T) -> T:
246
+ ...
247
+
248
+ def sqrt(self, x0: T) -> T:
249
+ ...
250
+
251
+ def relu(self, x0: T) -> T:
252
+ ...
253
+
254
+ def minimum(self, x0: T, x1: T) -> T:
255
+ ...
256
+
257
+ def maximum(self, x0: T, x1: T) -> T:
258
+ ...
259
+
260
+ def cos(self, x0: T) -> T:
261
+ ...
262
+
263
+ def sin(self, x0: T) -> T:
264
+ ...
265
+
266
+ def lgamma(self, x0: T) -> T:
267
+ ...
268
+
269
+ def erf(self, x0: T) -> T:
270
+ ...
271
+
272
+ def cosh(self, x0: T) -> T:
273
+ ...
274
+
275
+ def sinh(self, x0: T) -> T:
276
+ ...
277
+
278
+ def acos(self, x0: T) -> T:
279
+ ...
280
+
281
+ def acosh(self, x0: T) -> T:
282
+ ...
283
+
284
+ def asin(self, x0: T) -> T:
285
+ ...
286
+
287
+ def asinh(self, x0: T) -> T:
288
+ ...
289
+
290
+ def atan2(self, x0: T, x1: T) -> T:
291
+ ...
292
+
293
+ def atan(self, x0: T) -> T:
294
+ ...
295
+
296
+ def atanh(self, x0: T) -> T:
297
+ ...
298
+
299
+ def copysign(self, x0: T, x1: T) -> T:
300
+ ...
301
+
302
+ def erfc(self, x0: T) -> T:
303
+ ...
304
+
305
+ def erfinv(self, x0: T) -> T:
306
+ ...
307
+
308
+ def frexp(self, x0: T):
309
+ ...
310
+
311
+ def hypot(self, x0: T, x1: T) -> T:
312
+ ...
313
+
314
+ def log10(self, x0: T) -> T:
315
+ ...
316
+
317
+ def nextafter(self, x0: T, x1: T) -> T:
318
+ ...
319
+
320
+ def logical_and(self, x0: T, x1: T) -> T:
321
+ ...
322
+
323
+ def logical_not(self, x0: T) -> T:
324
+ ...
325
+
326
+ def logical_or(self, x0: T, x1: T) -> T:
327
+ ...
328
+
329
+ def logical_xor(self, x0: T, x1: T) -> T:
330
+ ...
331
+
332
+ def bitwise_and(self, x0: T, x1: T) -> T:
333
+ ...
334
+
335
+ def bitwise_not(self, x0: T) -> T:
336
+ ...
337
+
338
+ def bitwise_or(self, x0: T, x1: T) -> T:
339
+ ...
340
+
341
+ def bitwise_xor(self, x0: T, x1: T) -> T:
342
+ ...
343
+
344
+ def bitwise_left_shift(self, x0: T, x1: T) -> T:
345
+ ...
346
+
347
+ def bitwise_right_shift(self, x0: T, x1: T) -> T:
348
+ ...
349
+
350
+ def rsqrt(self, x0: T) -> T:
351
+ ...
352
+
353
+ def log1p(self, x0: T) -> T:
354
+ ...
355
+
356
+ def tan(self, x0: T) -> T:
357
+ ...
358
+
359
+ def tanh(self, x0: T) -> T:
360
+ ...
361
+
362
+ def sigmoid(self, x0: T) -> T:
363
+ ...
364
+
365
+ def signbit(self, x0: T) -> T:
366
+ ...
367
+
368
+ def fmod(self, x0: T, x1: T) -> T:
369
+ ...
370
+
371
+ def log(self, x0: T) -> T:
372
+ ...
373
+
374
+ def isinf(self, x0: T) -> T:
375
+ ...
376
+
377
+ def isnan(self, x0: T) -> T:
378
+ ...
379
+
380
+ def round(self, x0: T) -> T:
381
+ ...
382
+
383
+ def floor(self, x0: T) -> T:
384
+ ...
385
+
386
+ def sign(self, x0: T) -> T:
387
+ ...
388
+
389
+ def to_int(self, x0: T) -> T:
390
+ ...
391
+
392
+ def trunc(self, x0: T) -> T:
393
+ ...
394
+
395
+ def truncdiv(self, x0: T, x1: T) -> T:
396
+ ...
397
+
398
+ def ceil(self, x0: T) -> T:
399
+ ...
400
+
401
+ def neg(self, x0: T) -> T:
402
+ ...
403
+
404
+ def reciprocal(self, x0: T) -> T:
405
+ ...
406
+
407
+ def eq(self, x0: T, x1: T) -> T:
408
+ ...
409
+
410
+ def ne(self, x0: T, x1: T) -> T:
411
+ ...
412
+
413
+ def lt(self, x0: T, x1: T) -> T:
414
+ ...
415
+
416
+ def gt(self, x0: T, x1: T) -> T:
417
+ ...
418
+
419
+ def le(self, x0: T, x1: T) -> T:
420
+ ...
421
+
422
+ def ge(self, x0: T, x1: T) -> T:
423
+ ...
424
+
425
+ def add(self, x0: T, x1: T) -> T:
426
+ ...
427
+
428
+ def sub(self, x0: T, x1: T) -> T:
429
+ ...
430
+
431
+ def mul(self, x0: T, x1: T) -> T:
432
+ ...
433
+
434
+ def floordiv(self, x0: T, x1: T) -> T:
435
+ ...
436
+
437
+ def truediv(self, x0: T, x1: T) -> T:
438
+ ...
439
+
440
+ def div(self, x0: T, x1: T) -> T:
441
+ ...
442
+
443
+ def mod(self, x0: T, x1: T) -> T:
444
+ ...
445
+
446
+ def pow(self, x0: T, x1: T) -> T:
447
+ ...
448
+
449
+ def and_(self, x0: T, x1: T) -> T:
450
+ ...
451
+
452
+ def or_(self, x0: T, x1: T) -> T:
453
+ ...
454
+
455
+ def xor(self, x0: T, x1: T) -> T:
456
+ ...
457
+
458
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
459
+ # In CUDA, optimized implementations of other mathematical operations are
460
+ # offered separately via libdevice for double precision computation (in
461
+ # Triton, these go to tl.math rather than tl). We lower to these
462
+ # operators when doing FP64 on CUDA. Note that some operators
463
+ # unconditional go to tl.math.
464
+ #
465
+ # TODO(ezyang): Is this really the best way to do this? What if we have
466
+ # abs internally route to tl.math automatically when given a double
467
+ # precision input? One reason is that when doing codegen, we often don't
468
+ # know what the dtype of the inputs are! (In principle we do know, but
469
+ # for many analyses it's not conveniently available.)
470
+
471
+ def libdevice_abs(self, x0: T) -> T:
472
+ ...
473
+
474
+ def libdevice_exp(self, x0: T) -> T:
475
+ ...
476
+
477
+ def libdevice_sqrt(self, x0: T) -> T:
478
+ ...
479
+
480
+ def libdevice_cos(self, x0: T) -> T:
481
+ ...
482
+
483
+ def libdevice_sin(self, x0: T) -> T:
484
+ ...
485
+
486
+ def libdevice_sigmoid(self, x0: T) -> T:
487
+ ...
488
+
489
+ def libdevice_log(self, x0: T) -> T:
490
+ ...
491
+
492
+
493
class MockHandler:
    """Ops handler whose every op renders to a string like ``ops.add(a, b)``.

    KernelFormatterHandler drives an instance of this class to produce a
    readable textual dump of an inner function instead of executing it.
    """

    def __getattr__(self, name):
        # The ``name`` attribute is special-cased so that printing/logging the
        # handler itself stays readable.
        if name == "name":
            return "MockHandler"

        def render(*args, **kwargs):
            fargs = [_arg_str(a) for a in args] + [
                f"{k}={v}" for k, v in kwargs.items()
            ]
            return f"ops.{name}({', '.join(fargs)})"

        return render

    @staticmethod
    def masked(mask, body, other) -> str:
        # ``body`` is a callable producing the masked sub-expression.
        return f"ops.masked({mask}, {body()}, {other})"

    @staticmethod
    def frexp(x):
        # Two results: render one indexed expression per component.
        return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]")

    @staticmethod
    def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
        return sympy_index_symbol(f"({str(index_var)})")

    @classmethod
    def _init_cls(cls):
        # Stamp out one formatting handler per magic/inplace method so those
        # ops bypass ``__getattr__``.
        def make_handler(template):
            @staticmethod  # type: ignore[misc]
            def handler(*args):
                return template.format(*args)

            return handler

        for op_name, template in itertools.chain(
            magic_methods.items(), inplace_methods.items()
        ):
            setattr(cls, op_name, make_handler(template))
530
+
531
+
532
+ MockHandler._init_cls()
533
+
534
+
535
+ # Use mypy to check protocol implemented correctly
536
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
    # Static-analysis helper: mypy reports an error here if MockHandler does
    # not satisfy the OpsHandler[str] protocol.
    return h
538
+
539
+
540
class KernelFormatterHandler:
    """Ops handler that pretty-prints the ops issued by an IR function.

    Each op is forwarded to ``parent_handler`` (a MockHandler in
    ``ir_to_string``, which renders ops as strings) and the rendered line is
    bound to a fresh ``tmpN`` variable in ``self.output``.
    """

    def __init__(self, parent_handler):
        self.parent_handler = parent_handler
        # Buffer starts at indent level 1 — lines land inside the generated
        # ``def inner_fn`` (see ir_to_string).
        self.output = IndentedBuffer(1)
        # Monotonic counter for the generated ``tmpN`` names.
        self.var_counter = itertools.count()

    @staticmethod
    def ir_to_string(ir_fn, index, rindex=None) -> str:
        """Render ``ir_fn(index[, rindex])`` as the source text of a
        ``def inner_fn(...)`` function."""
        from .ir import FlexibleLayout
        from .virtualized import V

        args = [index, rindex] if rindex is not None else [index]
        names = ["index", "rindex"] if rindex is not None else ["index"]
        formatter = KernelFormatterHandler(MockHandler())

        # indent(-1) puts the ``def`` header one level left of the body lines.
        with formatter.output.indent(-1):
            formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
        for name, arg in zip(names, args):
            if arg:
                # Unpack each index tuple; constant entries become ``_``.
                lhs = ", ".join(
                    [
                        str("_" if isinstance(v, (int, sympy.Integer)) else v)
                        for v in arg
                    ]
                )
                formatter.output.writeline(f"{lhs} = {name}")

        # Route all ops through this formatter while temporarily allowing
        # arbitrary indexing on flexible layouts.
        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = ir_fn(*args)
            return formatter.getvalue(result)

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            line = getattr(self.parent_handler, name)(*args, **kwargs)
            # indirect_indexing yields a sympy expression, not an op line, so
            # it is passed through untouched.
            if name == "indirect_indexing":
                return line

            def write(line):
                # replace line with a new variable name
                varname = f"tmp{next(self.var_counter)}"
                self.output.writeline(f"{varname} = {line}")
                return varname

            return pytree.tree_map(write, line)

        return inner

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[str, Tuple[str, ...]],
    ) -> Union[str, Tuple[str, ...]]:
        # A reduction may produce several outputs; bind each to its own tmp
        # name on a single generated line.
        line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
        num_values = reduction_num_outputs(reduction_type)
        varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
        self.output.writeline(f"{','.join(varnames)} = {line}")
        return tuple(varnames) if num_values > 1 else varnames[0]

    def getvalue(self, result):
        # Finish the generated function with a return and emit its full text.
        self.output.writeline(f"return {result}")
        return self.output.getvalue()
605
+
606
+
607
+ # Use mypy to check protocol implemented correctly
608
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
    # Static-analysis helper: mypy reports an error here if
    # KernelFormatterHandler does not satisfy the OpsHandler[str] protocol.
    return h
610
+
611
+
612
class WrapperHandler(Generic[T]):
    """Transparent proxy around another OpsHandler.

    A subclass can define individual ops; anything it does not define falls
    through to the wrapped handler via ``__getattr__``.
    """

    def __init__(self, inner: OpsHandler[T]):
        self._inner = inner

    def __getattr__(self, item):
        # __getattr__ is only invoked for attributes not found on the object
        # itself, so attributes defined on a subclass take precedence.
        return getattr(self._inner, item)
618
+
619
+
620
+ # Use mypy to check protocol implemented correctly
621
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
    # Static-analysis helper: mypy reports an error here if WrapperHandler[T]
    # does not satisfy the OpsHandler[T] protocol.
    return h
623
+
624
+
625
class OpCounterCSE:
    """Wraps another ops handler and counts the distinct op results seen.

    Identical op results (as rendered by ``parent_handler``) share one
    ``tmpN`` name, so ``op_count`` reflects the post-CSE op count.
    """

    def __init__(self, inner):
        super().__init__()
        self.parent_handler = inner
        self.op_count = 0
        self.var_names = {}

    def __getattr__(self, name):
        def dispatch(*args, **kwargs):
            result = getattr(self.parent_handler, name)(*args, **kwargs)
            # indirect_indexing does not produce an op result; pass it
            # through without any CSE bookkeeping.
            if name == "indirect_indexing":
                return result

            def assign_name(value):
                try:
                    # Previously seen result: reuse its variable name.
                    return self.var_names[value]
                except KeyError:
                    fresh = f"tmp{self.op_count}"
                    self.op_count += 1
                    self.var_names[value] = fresh
                    return fresh

            return pytree.tree_map(assign_name, result)

        return dispatch
652
+
653
+
654
def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
    # Static-analysis helper: mypy reports an error here if OpCounterCSE does
    # not satisfy the OpsHandler[str] protocol.
    return h
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import sympy
4
+
5
+ import torch
6
+ from torch.utils._sympy.value_ranges import ValueRanges
7
+ from .ir import LoopBody
8
+ from .utils import dominated_nodes
9
+
10
+
11
def val_expressable_in_32_bits(val):
    """Return True if ``val`` can be represented exactly in 32 bits.

    Integers are checked against the int32 range; floats against the 24-bit
    mantissa of float32 (i.e. |val| <= 2**24), so conversion to float32 is
    exact. Sympy numbers are first converted to the matching python type;
    boolean values always fit.

    Raises:
        TypeError: if ``val`` is not a bool/int/float or a sympy number.
    """
    # sympy boolean atoms trivially fit (python bools fall through to the
    # int branch below, since bool has no ``is_Boolean`` attribute)
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_number
        if val.is_Integer or val.is_Boolean:
            val = int(val)
        else:
            val = float(val)

    # bound within float32 mantissa so the value survives an exact
    # int <-> float round-trip
    if isinstance(val, float):
        return -(2**24) <= val <= 2**24

    if isinstance(val, int):
        iinfo = torch.iinfo(torch.int32)
        return iinfo.min <= val <= iinfo.max

    # TypeError is more precise than the bare Exception previously raised
    # here, and remains catchable by callers expecting Exception.
    raise TypeError(f"Unexpected value {val}")
31
+
32
+
33
def range_expressable_in_32_bits(range):
    """True iff both endpoints of the value range fit in 32 bits."""
    return all(
        val_expressable_in_32_bits(bound) for bound in (range.lower, range.upper)
    )
37
+
38
+
39
def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
    """Downcast one ``to_dtype(..., torch.int64)`` node to int32 if safe.

    Safe means: every dominated use of ``node`` has a value range that fits
    in 32 bits, and every index expression reachable through an indirect
    variable it feeds is also 32-bit expressible. On success the node is
    mutated in place (its dtype argument ``args[2]`` becomes
    ``torch.int32``); otherwise the node is left untouched.
    """
    # if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
    # then it's precision is set for that chain of uses, and we don't need to consider those
    # dominated values
    # NOTE(review): the comment above mentions float16, but the tuple below
    # does not include torch.float16 — confirm whether that is intentional.
    def skip_filter(node):
        return node.target == "to_dtype" and node.args[2] in (
            torch.int32,
            torch.float32,
            torch.float64,
        )

    # TODO - there are dominated uses whose dtype does not depend on whether
    # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
    # int32 without changing the output precision of the node. this case hasn't shown up
    for dominated in dominated_nodes([node], skip_filter):
        # stores/outputs have no value range of their own to check
        if dominated.target in ["store", "output"]:
            continue

        if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
            # target is "set_indirect<N>"; recover N to find the variable
            idx = int(dominated.target[len("set_indirect") :])
            indirect_var = indirect_vars[idx]

            # We check that we can compute all the indices it's involved in with int32
            for index, expr in indices.items():
                if indirect_var in expr.free_symbols:
                    index_val = replacement_vals[index]

                    # unbounded range: cannot prove anything, bail out
                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                        return

                    # all indices are integers, so make sure that we
                    # use the bounds of integers instead of floats.
                    # TODO - not sure if we should be doing int/float casts while tracing,
                    # might interfere with sympy.

                    index_val_int = ValueRanges[sympy.Expr](
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

        if not range_expressable_in_32_bits(bounds[dominated]):
            return

    # Every dominated use checked out: rewrite the dtype argument in place.
    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)
86
+
87
+
88
def indexing_dtype_strength_reduction(loop_body: LoopBody):
    """
    Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
    intermediaries from int64 to int32
    """
    bv = loop_body.bounds()

    def is_downcast_candidate(node):
        # only int64 casts whose value range is known can be narrowed
        return (
            node.target == "to_dtype"
            and node.args[2] == torch.int64
            and node not in bv.unbounded_vars
        )

    int64_dtype_nodes = [
        node for node in loop_body.get_nodes() if is_downcast_candidate(node)
    ]
    if not int64_dtype_nodes:
        return

    bounds = bv.get_bounds()

    # TODO - if dominated node of one to_dtype is not expressible in int32,
    # we should short circuit another to_dtype node if that node also dominates
    for node in int64_dtype_nodes:
        try_to_reduce_precision(
            node,
            bounds,
            loop_body.indirect_vars,
            loop_body.indexing_exprs,
            bv.replacement_vals,
        )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_heuristics.py ADDED
@@ -0,0 +1,1527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import builtins
2
+ import copy
3
+ import functools
4
+ import hashlib
5
+ import inspect
6
+ import json
7
+ import logging
8
+ import math
9
+ import operator
10
+ import os
11
+ import os.path
12
+ import re
13
+ import threading
14
+ from enum import auto, Enum
15
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple
16
+
17
+ import torch
18
+
19
+ import torch.autograd.profiler as autograd_profiler
20
+ from torch._dynamo.device_interface import get_interface_for_device
21
+ from torch._dynamo.utils import dynamo_timed, get_first_attr
22
+ from torch.utils._triton import has_triton_package
23
+
24
+ from . import config
25
+ from .codecache import cache_dir, CudaKernelParamCache
26
+ from .coordinate_descent_tuner import CoordescTuner
27
+
28
+ from .ir import ReductionHint, TileHint
29
+ from .utils import (
30
+ ceildiv,
31
+ conditional_product,
32
+ create_bandwidth_info_str,
33
+ do_bench,
34
+ get_max_y_grid,
35
+ get_num_bytes,
36
+ next_power_of_2,
37
+ triton_config_to_hashable,
38
+ )
39
+
40
+
41
+ log = logging.getLogger(__name__)
42
+
43
+ if has_triton_package():
44
+ import triton
45
+ from triton import Config
46
+ from triton.runtime.autotuner import OutOfResources
47
+ from triton.runtime.jit import KernelInterface
48
+
49
+ try:
50
+ from triton.compiler.compiler import ASTSource
51
+ except ImportError:
52
+ ASTSource = None
53
+ else:
54
+ Config = object
55
+ triton = None
56
+ KernelInterface = object
57
+ OutOfResources = object
58
+ ASTSource = None
59
+
60
+
61
+ _NUM_THREADS_PER_WARP = 32
62
+
63
+
64
class HeuristicType(Enum):
    # Identifies which heuristic family produced a kernel's autotuning
    # configs. CachingAutotuner consults this: e.g. REDUCTION enables the
    # dynamic-RBLOCK scaling in precompile(), and TEMPLATE/USER_AUTOTUNE
    # skip coordinate descent tuning.
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()
71
+
72
+
73
class AutotuneHint(Enum):
    # Hints attached to kernel metadata suggesting extra autotuning configs;
    # consumed by autotune_hints_to_configs().
    ELEMENTS_PER_WARP_32 = 0

    # Triton codegen tries to codegen set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>""
    # which isn't valid python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__
81
+
82
+
83
def autotune_hints_to_configs(
    hints: Set[AutotuneHint], size_hints, block_size: int
) -> List[Config]:
    """
    AutotuneHints can be attached to the metadata of triton kernels for providing
    suggestions about what to try for autotuning. One reason to do this is if there are
    some configs that are only useful in specific scenarios, in which case we can avoid
    wasting compile time on autotuning unless we know we are in one of those scenarios.

    Based on those hints, this function will generate a list of additional autotuning
    configs to try.

    Raises:
        ValueError: if ``size_hints`` does not have 1, 2 or 3 entries while an
            ELEMENTS_PER_WARP_32 hint is present.
    """
    xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
    configs = []

    for hint in hints:
        if hint == AutotuneHint.ELEMENTS_PER_WARP_32:
            # One candidate per dimension: shrink that dimension's block to a
            # quarter (32 elements/warp with the default 128) and pin the
            # others to 1.
            if len(size_hints) == 1:
                xyz_options = ((block_size // 4, None, None),)
            elif len(size_hints) == 2:
                xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
            elif len(size_hints) == 3:
                xyz_options = (
                    (block_size // 4, 1, 1),
                    (1, block_size // 4, 1),
                    (1, 1, block_size // 4),
                )
            else:
                # Previously this fell through with ``xyz_options`` unbound
                # and crashed with an opaque NameError below; fail loudly
                # with a descriptive error instead.
                raise ValueError(
                    f"Expected 1, 2, or 3 size_hints, got {len(size_hints)}"
                )
            for xyz in xyz_options:
                configs.append(
                    triton_config(
                        size_hints,
                        *xyz,
                        num_elements_per_warp=32,
                    )
                )

    return configs
120
+
121
+
122
def disable_pointwise_autotuning():
    """Return True when pointwise autotuning should be skipped."""
    # Autotuning can give different benchmarking results from run to run, so
    # it is disabled whenever deterministic algorithms are requested; it can
    # also be turned off explicitly via config.triton.autotune_pointwise.
    return (
        torch.are_deterministic_algorithms_enabled()
        or not config.triton.autotune_pointwise
    )
128
+
129
+
130
+ class CachingAutotuner(KernelInterface):
131
+ """
132
+ Simplified version of Triton autotuner that has no invalidation
133
+ key and caches the best config to disk to improve cold start times.
134
+ Unlike the main triton Autotuner, this version can precompile all
135
+ configs, and does not rely on the Triton JIT.
136
+ """
137
+
138
    def __init__(
        self,
        fn,
        triton_meta,  # passed directly to triton
        configs,
        save_cache_hook,
        mutated_arg_names,
        heuristic_type,
        size_hints=None,
        inductor_meta=None,  # metadata not relevant to triton
        custom_kernel=False,  # whether the kernel is inductor-generated or custom
    ):
        """Store kernel metadata; actual compilation is deferred to precompile().

        Side effects: sets the TRITON_CACHE_DIR environment variable if unset,
        and pre-creates a profiler record-function context.
        """
        super().__init__()

        assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
        self.fn = fn
        self.triton_meta = triton_meta
        self.inductor_meta = {} if inductor_meta is None else inductor_meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.configs = configs
        self.heuristic_type = heuristic_type
        self.custom_kernel = custom_kernel
        # flipped to True by save_cuda_kernel()
        self.cuda_kernel_saved = False

        # Align the default design that default as cuda
        self.device_type = (
            triton_meta["device_type"] if "device_type" in triton_meta else "cuda"
        )
        self.gpu_device = get_interface_for_device(self.device_type)

        if log.isEnabledFor(logging.DEBUG):
            log.debug(
                "CachingAutotuner gets %d configs for %s",
                len(self.configs),
                self.fn.__name__,
            )
            for c in self.configs:
                log.debug(c)

        # populated by precompile(); guarded by self.lock
        self.launchers = []
        self.lock = threading.Lock()
        if os.getenv("TRITON_CACHE_DIR") is None:
            # keep triton's on-disk cache per-device under inductor's cache dir
            os.environ["TRITON_CACHE_DIR"] = os.path.join(
                cache_dir(),
                "triton",
                str(self.triton_meta.get("device", 0)),
            )

        self.size_hints = size_hints
        self.coordesc_tuner = CoordescTuner(
            is_mm=False, name=self.fn.__name__, size_hints=size_hints
        )

        # pre-create the profiler context manager to reduce latency
        self.record_function_ctx = torch._C._profiler._RecordFunctionFast(
            self.inductor_meta.get("kernel_name", "triton kernel")
        )
196
+
197
    def precompile(self, warm_cache_only_with_cc=None):
        """Compile all candidate configs into launchers (thread-safe, idempotent).

        Configs that exhaust GPU resources are silently dropped; for large
        reductions on capable NVIDIA GPUs, extra halved-RBLOCK variants may be
        added to reduce register pressure. Clears ``self.configs`` when done.
        """
        with self.lock:
            # already precompiled by another caller/thread
            if self.launchers:
                return
            self.launchers = []
            compiled_binaries = []
            if not self.configs:
                raise RuntimeError("No triton configs are available")

            for c in self.configs:
                try:
                    compiled_binary, launcher = self._precompile_config(
                        c, warm_cache_only_with_cc
                    )
                except OutOfResources:
                    # Skip the config if we run out of resource
                    continue
                self.launchers.append(launcher)
                compiled_binaries.append(compiled_binary)

            if len(self.launchers) == 0:
                raise RuntimeError(
                    "No valid triton configs. Report a fatal compilation error"
                )

            # used below to avoid re-adding an RBLOCK-halved config that
            # already exists
            seen_configs = set(self.configs)

            device_prop = self.gpu_device.Worker.get_device_properties(
                self.triton_meta["device"]
            )
            if (
                config.dynamic_scale_rblock
                and self.heuristic_type == HeuristicType.REDUCTION
                and self.size_hints is not None
                # Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary.
                and torch.version.hip is None
                and device_prop.major >= 8
            ):
                # NOTE(review): the loop variable below shadows the
                # module-level ``triton_config`` helper within this scope.
                for triton_config, compiled_binary in zip(
                    self.configs, compiled_binaries
                ):
                    assert len(self.size_hints) == 2
                    xblock = triton_config.kwargs.get("XBLOCK", 1)
                    rblock = triton_config.kwargs["RBLOCK"]
                    # number of thread blocks = ceil(x extent / XBLOCK)
                    total_block = (self.size_hints[0] + xblock - 1) // xblock
                    nreg = getattr(compiled_binary, "n_regs", None)
                    if nreg is None:
                        continue

                    # make sure rblock is not too small
                    if rblock <= 64:
                        continue

                    # each SM of A100 has 65536 32-bit registers. To maximize
                    # the theoretical occupancy, we need run 2048 threads on each
                    # SM. So each thread should use no more than 65536 / 2048
                    # = 32 registers. In cases where occupancy matters, and each
                    # thread uses too many registers, reduce RBLOCK to reduce
                    # the register usage.
                    # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
                    # from PLBartForCausalLM, latency improve from
                    # 7.795ms to 4.883ms.
                    #
                    if (
                        nreg
                        <= device_prop.regs_per_multiprocessor
                        // device_prop.max_threads_per_multi_processor
                    ):
                        continue

                    nreg_per_warp = nreg * 32
                    nreg_per_block = nreg_per_warp * triton_config.num_warps

                    # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)'
                    # The formula below is a tighter upper bound since we have the assumption that
                    # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
                    # due to the if condition above and:
                    # regs_per_multiprocessor / nreg_per_block
                    # = regs_per_multiprocessor / (nreg * 32 * num_warps)
                    # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
                    # = max_threads_per_multi_processor / (32 * num_warps)
                    # Using a tighter upper bound can reveal more optimization opportunities.
                    max_blocks_per_sm = max(
                        device_prop.regs_per_multiprocessor // nreg_per_block, 1
                    )

                    if (
                        total_block
                        <= max_blocks_per_sm * device_prop.multi_processor_count
                    ):
                        # no need to improve occupancy
                        continue
                    new_config = copy.deepcopy(triton_config)
                    new_config.kwargs["RBLOCK"] = rblock // 2
                    if new_config in seen_configs:
                        continue
                    seen_configs.add(new_config)
                    self.launchers.append(
                        self._precompile_config(new_config, warm_cache_only_with_cc)[1]
                    )
            # launchers now carry everything needed; drop the raw config list
            self.configs = None
298
+
299
    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]):
        """Ahead of time compile a given autotuner config.

        Returns ``(compiled_binary, launcher)``; in warm-cache-only mode the
        launcher is None because no device context is available.
        """
        compile_meta = copy.deepcopy(self.triton_meta)
        # bake the config's block sizes etc. in as constexpr constants
        for k, v in cfg.kwargs.items():
            compile_meta["constants"][self.fn.arg_names.index(k)] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages
        compile_meta["debug"] = (
            config.assert_indirect_indexing and torch.version.hip is None
        )

        # Setting device_type="hip" required on ROCm to pass down to triton
        compile_meta["device_type"] = (
            self.device_type if torch.version.hip is None else "hip"
        )

        if warm_cache_only_with_cc:
            cc = warm_cache_only_with_cc
        else:
            # Use device_type 'cuda' for both cuda and hip devices to retrieve
            # the compute capability.
            device_type = self.device_type if torch.version.hip is None else "cuda"
            device_id = compile_meta["device"]
            device = torch.device(device_type, device_id)
            cc = self.gpu_device.get_compute_capability(device)

        compile_meta["cc"] = cc

        # Newer triton exposes the ASTSource API; fall back to the legacy
        # kwargs-based triton.compile signature otherwise.
        if ASTSource:
            compile_args = (
                ASTSource(
                    self.fn,
                    compile_meta["signature"],
                    compile_meta["constants"],
                    compile_meta["configs"][0],
                ),
            )

            target = (compile_meta["device_type"], cc)
            options = {
                "num_warps": compile_meta["num_warps"],
                "num_stages": compile_meta["num_stages"],
                "debug": compile_meta["debug"],
            }
            compile_kwargs = {
                "target": target,
                "options": options,
            }
        else:
            compile_args = (self.fn,)
            compile_kwargs = compile_meta

        if warm_cache_only_with_cc:
            # populate the on-disk cache only; no launcher can be built
            # without an initialized device context
            return (
                triton.compile(*compile_args, **compile_kwargs),
                None,
            )

        # load binary to the correct device
        with self.gpu_device.device(compile_meta["device"]):  # type: ignore[attr-defined]
            # need to initialize context
            self.gpu_device.synchronize(self.gpu_device.current_device())

            try:
                binary = triton.compile(*compile_args, **compile_kwargs)
            except Exception:
                log.exception(
                    "Triton compilation failed: %s\n%s\nmetadata: %s",
                    self.inductor_meta.get("kernel_name", "triton_"),
                    self.fn.src,
                    compile_meta,
                )
                raise
            binary._init_handles()

        # runtime call args: everything that is not a constexpr
        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs
        ]
        # launcher signature: everything not baked in by this config
        def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs]

        # names visible to the exec'd launcher source below
        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "launch_enter_hook": binary.launch_enter_hook,
            "launch_exit_hook": binary.launch_exit_hook,
            "metadata": binary.metadata,
            "torch": torch,
            "set_device": self.gpu_device.set_device,
            "current_device": self.gpu_device.current_device,
        }

        # get_first_attr bridges attribute renames across triton versions
        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
        scope["function"] = get_first_attr(binary, "function", "cu_function")
        scope["cta_args"] = (
            (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
            if hasattr(binary, "num_ctas")
            else (
                (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
                if hasattr(binary, "metadata")
                else ()
            )
        )
        scope["num_warps"] = (
            binary.num_warps
            if hasattr(binary, "num_warps")
            else binary.metadata.num_warps
        )
        binary_shared = (
            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
        )
        scope["shared"] = binary_shared

        # Build a specialized launcher function via exec so the hot path has
        # no per-call introspection overhead.
        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid

                runner(grid_0, grid_1, grid_2, num_warps,
                            *cta_args, shared,
                            stream, function,
                            launch_enter_hook,
                            launch_exit_hook,
                            metadata,
                            {', '.join(call_args)})
                return bin
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        launcher.n_regs = getattr(binary, "n_regs", None)
        launcher.n_spills = getattr(binary, "n_spills", None)
        launcher.shared = binary_shared
        launcher.store_cubin = config.triton.store_cubin
        # store this global variable to avoid the high overhead of reading it when calling run
        if launcher.store_cubin:
            launcher.fn = self.fn
            launcher.bin = binary

        return binary, launcher
445
+
446
    def bench(self, launcher, *args, grid, **kwargs):
        """Measure the performance of a given launcher.

        Returns a latency estimate from ``do_bench``, or ``inf`` for
        inductor-generated configs that spill too many registers.
        """
        # we don't skip configs with spilled registers when auto-tuning custom
        # (user-written) Triton kernels, as (i) we don't have any knowledge or
        # control over the kernel code; (ii) there is empirical evidence that
        # for some (complicated) custom Triton kernels, a register-spilling
        # config may yield the best latency.
        if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold:
            log.debug(
                "Skip config %s because of register spilling: %d",
                launcher.config,
                launcher.n_spills,
            )
            # infinite latency: this config can never win the autotune
            return float("inf")

        stream = self.gpu_device.get_raw_stream(  # type: ignore[call-arg]
            self.gpu_device.current_device()
        )

        def kernel_call():
            # the pre_hook receives the full name->value argument mapping
            # merged with this config's kwargs
            if launcher.config.pre_hook is not None:
                launcher.config.pre_hook(
                    {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
                )

            # clone mutated args so repeated benchmark runs don't corrupt
            # real data (see clone_args)
            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
            launcher(
                *cloned_args,
                **cloned_kwargs,
                grid=grid,
                stream=stream,
            )

        return do_bench(kernel_call, rep=40, fast_flush=True)
480
+
481
+ def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
482
+ from .compile_fx import clone_preserve_strides
483
+
484
+ # clone inplace buffers to avoid autotune contaminating them if
485
+ # the kernel does in-place stores. avoid cloning other buffers because
486
+ # it leads to increase memory use
487
+ cloned_args = []
488
+ for i, arg in enumerate(args):
489
+ if self.fn.arg_names[i] in self.mutated_arg_names:
490
+ assert isinstance(arg, torch.Tensor)
491
+ cloned_args.append(clone_preserve_strides(arg))
492
+ else:
493
+ cloned_args.append(arg)
494
+
495
+ cloned_kwargs: Dict[str, Any] = {}
496
+ for name, arg in kwargs.items():
497
+ if name in self.mutated_arg_names:
498
+ assert isinstance(arg, torch.Tensor)
499
+ cloned_kwargs[name] = clone_preserve_strides(arg)
500
+ else:
501
+ cloned_kwargs[name] = arg
502
+
503
+ return cloned_args, cloned_kwargs
504
+
505
    @dynamo_timed
    def benchmark_all_configs(self, *args, **kwargs):
        """Benchmark every compiled launcher; return {launcher: latency}."""
        timings = {
            launcher: self.bench(launcher, *args, **kwargs)
            for launcher in self.launchers
        }

        # record results in the coordinate-descent tuner's benchmark cache
        for k, v in timings.items():
            self.coordesc_tuner.cache_benchmark_result(k.config, v)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
            for k, v in timings.items():
                log.debug(
                    "%s: %f, nreg %d, nspill %d, #shared-mem %s",
                    k.config,
                    v,
                    k.n_regs,
                    k.n_spills,
                    k.shared,
                )

        return timings
528
+
529
+ def autotune_to_one_config(self, *args, **kwargs):
530
+ """Do the actual autotuning"""
531
+ timings = self.benchmark_all_configs(*args, **kwargs)
532
+ self.launchers = [builtins.min(timings, key=timings.get)]
533
+ if self.save_cache_hook:
534
+ self.save_cache_hook(self.launchers[0].config)
535
+
536
    def save_cuda_kernel(self, grid, stream, launcher):
        """Store the chosen kernel's binary and launch parameters in
        CudaKernelParamCache, keyed by the kernel name from inductor_meta."""
        # resolve the concrete grid now so fixed values can be cached
        if callable(grid):
            grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
        else:
            grid_x, grid_y, grid_z = grid

        key = self.inductor_meta.get("kernel_name", None)  # unique kernel name
        assert key is not None, "kernel_name can not be None"
        params = {
            # attribute vs. dict access bridges different triton versions
            "mangled_name": launcher.bin.metadata.name
            if hasattr(launcher.bin.metadata, "name")
            else launcher.bin.metadata["name"],
            "grid_x": grid_x,
            "grid_y": grid_y,
            "grid_z": grid_z,
            "x_block": launcher.config.kwargs.get("XBLOCK", 1),
            "y_block": launcher.config.kwargs.get("YBLOCK", None),
            "z_block": launcher.config.kwargs.get("ZBLOCK", None),
            "num_warps": launcher.bin.num_warps
            if hasattr(launcher.bin, "num_warps")
            else launcher.bin.metadata.num_warps,
            "shared_mem": launcher.bin.shared
            if hasattr(launcher.bin, "shared")
            else launcher.bin.metadata.shared,
            "stream": stream,
            # User defined triton kernels will have arbitrary kwarg names
            "meta": launcher.config.kwargs,
        }

        if torch.version.hip is None:
            CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
        else:
            # There is some divergence between CUDA and ROCm here.
            # On ROCm's triton we only have the path to the binary, not the binary itself.
            # For ROCm we will copy the binary to the new location instead of writing to file
            import pathlib

            launcher.bin.asm["hsaco"] = pathlib.Path(
                launcher.bin.asm["hsaco_path"]
            ).read_bytes()
            CudaKernelParamCache.set(key, params, launcher.bin.asm["hsaco"])

        self.cuda_kernel_saved = True
579
+
580
    def coordinate_descent_tuning(self, launcher, *args, **kwargs):
        """
        Coordinate descent tuning can be run with or without max-autotune.

        The only difference between these two is the starting config for coordinate_descent tuning.
        E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4
        and max-autotune figure out C3 is the best.

        Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.

        Returns the launcher for the best config found (or the input launcher
        for template/user-autotune kernels, which are skipped).
        """
        if (
            self.heuristic_type == HeuristicType.TEMPLATE
            or self.heuristic_type == HeuristicType.USER_AUTOTUNE
        ):
            # skip triton template
            return launcher

        # clone once up front; the same cloned inputs are reused for every
        # candidate config benchmarked below
        cloned_args, _ = self.clone_args(*args)
        config2launcher = {launcher.config: launcher}

        def benchmark_one_config(config):
            with self.lock:
                _, launcher = self._precompile_config(config, None)
            config2launcher[config] = launcher

            out = self.bench(launcher, *cloned_args, **kwargs)
            log.debug(
                "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
                launcher.config,
                out,
                launcher.n_regs,
                launcher.n_spills,
                launcher.shared,
            )
            return out

        assert not (
            self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
            and "RBLOCK" in launcher.config.kwargs
        ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
        best_config = self.coordesc_tuner.autotune(
            benchmark_one_config, launcher.config, None
        )
        # mark so the cache hook can record how this config was found
        best_config.found_by_coordesc = True

        if self.save_cache_hook:
            self.save_cache_hook(best_config, found_by_coordesc=True)
        return config2launcher.get(best_config)
629
+
630
    def run(self, *args, grid, stream, **kwargs):
        """Launch the kernel, lazily compiling/autotuning down to a single
        launcher on the first call.

        Order matters: precompile -> autotune to one config ->
        (optional) coordinate descent tuning -> (optional) cubin save ->
        pre_hook -> launch.
        """
        if len(self.launchers) != 1:
            if len(self.launchers) == 0:
                self.precompile()
            if len(self.launchers) > 1:
                self.autotune_to_one_config(*args, grid=grid, **kwargs)

        # Only run coordinate descent if this config wasn't itself produced
        # by a previous coordinate descent pass (e.g. loaded from cache).
        if (
            not getattr(self.launchers[0].config, "found_by_coordesc", False)
            and config.coordinate_descent_tuning
        ):
            self.launchers = [
                self.coordinate_descent_tuning(
                    self.launchers[0], *args, grid=grid, **kwargs
                )
            ]

        (launcher,) = self.launchers
        if launcher.store_cubin:
            self.save_cuda_kernel(grid, stream, launcher)

        if launcher.config.pre_hook is not None:
            # pre_hook sees positional args by name plus config kwargs.
            launcher.config.pre_hook(
                {**dict(zip(self.arg_names, args)), **launcher.config.kwargs, **kwargs}
            )

        # guard the record_function_ctx and only call it if profiling is currently
        # in progress, to reduce latency when profiler is not turned on. Note that
        # the "if" statement (instead of, say, a contextlib.nullcontext) is intentional;
        # it is faster than entering and exiting a context manager, even if the context
        # manager is a nullcontext.
        if autograd_profiler._is_profiler_enabled:
            with self.record_function_ctx:
                return launcher(
                    *args,
                    **kwargs,
                    grid=grid,
                    stream=stream,
                )
        else:
            return launcher(
                *args,
                **kwargs,
                grid=grid,
                stream=stream,
            )
+ def _find_names(obj):
679
+ import gc
680
+ import inspect
681
+
682
+ frame = inspect.currentframe()
683
+ while frame is not None:
684
+ frame.f_locals
685
+ frame = frame.f_back
686
+ obj_names = []
687
+ for referrer in gc.get_referrers(obj):
688
+ if isinstance(referrer, dict):
689
+ for k, v in referrer.items():
690
+ if v is obj:
691
+ obj_names.append(k)
692
+ return obj_names
693
+
694
+
695
# Module-level accumulator of per-kernel bandwidth records appended by
# DebugAutotuner.run as (ms, num_gb, gb_per_s, kernel_name) tuples; reset by
# start_graph() and summarized by end_graph().
collected_calls: List[Any] = []
def start_graph():
    """Reset the bandwidth-profiling accumulator before a new graph runs."""
    del collected_calls[:]
def end_graph():
    """Print a bandwidth summary of all kernels recorded since
    ``start_graph()`` and, if ``config.profile_bandwidth_output`` is set,
    append a per-kernel report (sorted by runtime, descending) to that file.

    No-op when nothing was collected.
    """
    if len(collected_calls) == 0:
        return
    overall_time = sum(call[0] for call in collected_calls)
    overall_gb = sum(call[1] for call in collected_calls)
    # Report the caller's filename so summaries from different compiled
    # modules can be told apart.
    cur_file = inspect.stack()[1].filename
    summary_str = (
        f"SUMMARY ({cur_file})\n"
        f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
    )
    print(summary_str)
    print()
    output_file = config.profile_bandwidth_output
    if output_file is not None:
        # sort perf numbers in descending order, i.e. placing the
        # most runtime-heavy kernels at the top of the list
        sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
        try:
            with open(output_file, "a") as file:
                log.debug("Save profile bandwidth results to %s", output_file)
                file.write("====================\n")
                file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
                for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
                    # also display the runtime percentage for each kernel
                    percentage = f"{ms/overall_time*100:.2f}%"
                    suffix = f" \t {percentage} \t {kernel_name}"
                    bw_info_str = create_bandwidth_info_str(
                        ms,
                        num_gb,
                        gb_per_s,
                        suffix=suffix,
                        color=False,
                    )
                    file.write(bw_info_str + "\n")
                file.write(f"{summary_str}\n\n")
        except Exception as e:
            # Best effort: a failed report write should not break the run.
            log.warning(
                "failed to write profile bandwidth result into %s: %s",
                output_file,
                e,
            )
class DebugAutotuner(CachingAutotuner):
    """CachingAutotuner that additionally benchmarks each kernel once and
    records ``(ms, num_gb, gb_per_s, kernel_name)`` into the module-level
    ``collected_calls`` list for bandwidth profiling.

    Kernels whose (heuristically discovered) name does not match
    ``regex_filter`` are skipped entirely.
    """

    def __init__(self, *args, regex_filter="", **kwargs):
        self.regex_filter = regex_filter
        super().__init__(*args, **kwargs)
        # Memoized (ms, num_gb, gb_per_s, kernel_name) so each kernel is
        # benchmarked only once; later runs reuse the cached numbers.
        self.cached = None

    def run(self, *args, grid, stream):
        # Recover the kernel's variable name via gc referrers; the longest
        # candidate is used as the display name.
        possible_names = _find_names(self)
        kernel_name = f"{max(possible_names, key=len)}"
        if not re.match(self.regex_filter, kernel_name):
            return
        super().run(*args, grid=grid, stream=stream)
        (launcher,) = self.launchers

        if self.cached is None:
            ms = self.bench(launcher, *args, grid=grid)
            # in_out_ptr args are both read and written, so they count twice
            # toward the bytes moved (handled inside get_num_bytes).
            num_in_out_ptrs = len(
                [
                    arg_name
                    for arg_name in self.fn.arg_names
                    if arg_name.startswith("in_out_ptr")
                ]
            )
            # Prefer a precomputed byte count from inductor_meta; otherwise
            # estimate from the runtime arguments.
            num_gb = self.inductor_meta.get("kernel_num_gb", None)
            if num_gb is None:
                num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
            gb_per_s = num_gb / (ms / 1e3)
            self.cached = (ms, num_gb, gb_per_s, kernel_name)
        else:
            ms, num_gb, gb_per_s, kernel_name = self.cached
        collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
        print(
            create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}")
        )
def hash_configs(configs: List[Config]):
    """
    Hash used to check for changes in configurations

    Folds each config's sorted kwargs, num_warps and num_stages into a
    single sha256 hex digest; any change to the candidate set changes
    the digest, invalidating cached best-config results.
    """
    hasher = hashlib.sha256()
    for cfg in configs:
        line = f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n"
        hasher.update(line.encode())
    return hasher.hexdigest()
def load_cached_autotuning(
    best_config,
    configs_hash: str,
    configs: List[Config],
):
    """Turn a previously-saved best-config dict back into a triton Config.

    Returns None when there is nothing cached, when the cached entry was
    produced against a different candidate set (hash mismatch), or when it
    no longer matches exactly one of the current ``configs``.

    NOTE: ``best_config`` is consumed destructively via ``pop`` — the
    bookkeeping keys are stripped before matching against config kwargs.
    """
    if best_config is None:
        return None
    if best_config.pop("configs_hash", None) != configs_hash:
        return None

    # A coordinate-descent result may not be in the candidate list at all;
    # reconstruct it directly from the saved kwargs.
    if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False):
        num_warps = best_config.pop("num_warps")
        num_stages = best_config.pop("num_stages")
        triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
        triton_config.found_by_coordesc = True
        return triton_config

    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
        and cfg.num_warps == best_config.get("num_warps")
        and cfg.num_stages == best_config.get("num_stages")
    ]
    # Ambiguous or missing match means the cache entry is unusable.
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]
def cached_autotune(
    size_hints: Optional[List[int]],
    configs: List[Config],
    triton_meta,
    heuristic_type,
    filename=None,
    inductor_meta=None,
    custom_kernel=False,
):
    """
    A copy of triton.autotune that calls our subclass. Our subclass
    has additional debugging, error handling, and on-disk caching.

    Returns a decorator that wraps the compiled triton function in a
    CachingAutotuner (or DebugAutotuner when bandwidth profiling is on).
    Best-config results are cached locally on disk next to ``filename``
    and/or in a remote cache, keyed by a hash of the candidate configs.
    """
    configs = unique_configs(configs)
    assert len(configs) == 1 or filename
    save_cache_hook: Optional[Callable[[Any, Any], Any]]
    inductor_meta = {} if inductor_meta is None else inductor_meta

    # on disk caching logic and/or remote caching
    # Caching is only worthwhile when there is an actual choice to make
    # (multiple configs) or coordinate descent may refine the single config.
    if filename is not None and (len(configs) > 1 or config.coordinate_descent_tuning):
        configs_hash = hash_configs(configs)

        cache_filename = None
        remote_cache = None
        remote_cache_key = None
        if config.use_autotune_local_cache:
            cache_filename = os.path.splitext(filename)[0] + ".best_config"
        if config.use_autotune_remote_cache or (
            config.is_fbcode()
            and torch._utils_internal.justknobs_check(
                "pytorch/autotune_remote_cache:enable"
            )
        ):
            backend_hash = inductor_meta.get("backend_hash", None)
            if backend_hash is not None:
                # Remote key mixes the compiler backend version with the
                # candidate-set hash so stale results are never reused.
                key = backend_hash + configs_hash + "autotune-best-config"
                key = hashlib.sha256(key.encode("utf-8")).hexdigest()

                try:
                    if config.is_fbcode():
                        remote_cache = (
                            triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend(
                                key, is_autotune=True
                            )
                        )
                    else:
                        remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key)
                except Exception:
                    # Best effort: fall back to local-only caching.
                    remote_cache = None
                    log.warning("Unable to create a remote cache", exc_info=True)
                # we already sha256 hash the source contents
                remote_cache_key = os.path.basename(filename)
            else:
                log.debug(
                    "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache"
                )

        # Local cache takes priority over the remote one.
        best_config = None
        if cache_filename is not None and os.path.exists(cache_filename):
            with open(cache_filename) as fd:
                best_config = json.loads(fd.read())
        elif remote_cache is not None and remote_cache_key is not None:
            cache_outs = remote_cache.get([remote_cache_key])
            cache_out = cache_outs.get(remote_cache_key, None)
            best_config = json.loads(cache_out) if cache_out else None

        best_config = load_cached_autotuning(best_config, configs_hash, configs)
        if best_config:
            configs = [best_config]

        def save_cache_hook(cfg, found_by_coordesc=False):
            # Persist the winning config (plus bookkeeping keys consumed by
            # load_cached_autotuning) to whichever caches are available.
            data = json.dumps(
                {
                    **cfg.kwargs,
                    "num_warps": cfg.num_warps,
                    "num_stages": cfg.num_stages,
                    "configs_hash": configs_hash,
                    "found_by_coordesc": found_by_coordesc,
                }
            )
            if cache_filename is not None:
                with open(cache_filename, "w") as fd:
                    fd.write(data)
            if remote_cache is not None and remote_cache_key is not None:
                remote_cache.put(remote_cache_key, data)

            if log.isEnabledFor(logging.DEBUG):
                type_str = "coordesc" if found_by_coordesc else "heuristic"
                log.debug("Save %s tuning result to %s", type_str, cache_filename)

    else:
        save_cache_hook = None

    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())

    def decorator(fn):
        # Remove XBLOCK from config if it's not a function argument.
        # This way, coordinate descent tuning will not try to tune it.
        #
        # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
        import inspect

        if "XBLOCK" not in inspect.signature(fn.fn).parameters:
            for tconfig in configs:
                if "XBLOCK" in tconfig.kwargs:
                    assert tconfig.kwargs["XBLOCK"] == 1
                    tconfig.kwargs.pop("XBLOCK")

        if config.profile_bandwidth:
            return DebugAutotuner(
                fn,
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                regex_filter=config.profile_bandwidth_regex,
                configs=configs,
                save_cache_hook=save_cache_hook,
                mutated_arg_names=mutated_arg_names,
                heuristic_type=heuristic_type,
                size_hints=size_hints,
                custom_kernel=custom_kernel,
            )
        return CachingAutotuner(
            fn,
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            configs=configs,
            save_cache_hook=save_cache_hook,
            mutated_arg_names=mutated_arg_names,
            heuristic_type=heuristic_type,
            size_hints=size_hints,
            custom_kernel=custom_kernel,
        )

    return decorator
def unique_configs(configs: List[Config]):
    """Remove duplicate configurations, keeping first-occurrence order."""
    seen = set()
    deduped = []
    for cfg in configs:
        fingerprint = triton_config_to_hashable(cfg)
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        deduped.append(cfg)
    return deduped
def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
    """Validate a config's block sizes against the kernel numels.

    For each dimension whose numel is provided: a numel of 1 must use a
    block of 1, and every block must evenly divide the configured
    ``config.triton.max_block`` limit for that axis.
    """
    for numel, label in ((xnumel, "X"), (ynumel, "Y"), (znumel, "Z")):
        if numel is None:
            continue
        block = cfg[f"{label}BLOCK"]
        if numel == 1:
            assert block == 1, (
                f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
                f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
            )
        max_block = config.triton.max_block[label]
        max_block_str = f'config.triton.max_block["{label}"]'
        assert max_block % block == 0, (
            f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
            f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
        )
def triton_config(
    size_hints,
    x,
    y=None,
    z=None,
    num_stages=1,
    num_elements_per_warp=256,
    min_elem_per_thread=0,
) -> Config:
    """
    Construct a pointwise triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.

    num_elements_per_warp is a suggestion for controlling how many warps
    the triton config should contain. e.g.: if x=16, y=8, z=4 then
    num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
    we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
    just a suggestion, and sometimes other adjustment heuristics will
    override the num_elements_per_warp.

    min_elem_per_thread controls the minimum number of elements
    processed by each thread. It's always enforced.
    """
    # Ideally we want to read this from some device config

    # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK
    size_hints = list(reversed(size_hints))

    # Per-axis grid-dimension limits; presumably the CUDA limits
    # (x up to 2^31-1, y/z up to 65535) — TODO confirm for ROCm.
    maxGridSize = [2147483647, 65535, 65535]

    target = conditional_product(x, y, z)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    if y:
        y = min(y, size_hints[1])
    if z:
        z = min(z, size_hints[2])

    # if we are below original block size, scale up where we can;
    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
    while x < min(size_hints[0], config.triton.max_block["X"]) and (
        x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
    ):
        x *= 2
    while (
        y
        and y < min(size_hints[1], config.triton.max_block["Y"])
        and (
            y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
        )
    ):
        y *= 2
    while (
        z
        and z < min(size_hints[2], config.triton.max_block["Z"])
        and (
            z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
        )
    ):
        z *= 2

    num_warps = next_power_of_2(
        min(max(conditional_product(x, y, z) // num_elements_per_warp, 1), 8)
    )
    # we are going to arrive at 2 warps only if bs was too small due to
    # numel being too small. However to workaround some ptx bugs we still
    # want at least 4 warps if there's enough elements per thread
    # given that this is a rare situation, don't expect this to affect perf
    # in general
    # see https://github.com/pytorch/pytorch/pull/97950
    num_warps = max(num_warps, 4) if conditional_product(x, y, z) >= 128 else num_warps
    xnumel = size_hints[0]
    ynumel = size_hints[1] if y else None
    znumel = size_hints[2] if z else None

    # Increase x to satisfy min_elem_per_thread requirements.
    block_size = max(
        conditional_product(x, y, z),
        min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
    )
    x *= math.ceil(block_size / conditional_product(x, y, z))

    cfg = {"XBLOCK": x}
    if y:
        cfg["YBLOCK"] = y
    if z:
        cfg["ZBLOCK"] = z
    check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def triton_config_reduction(size_hints, x, r, num_stages=1, num_warps=None) -> Config:
    """
    Construct a reduction triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.

    ``x``/``r`` are the requested XBLOCK/RBLOCK; they are clamped to the
    size hints and then grown back toward the original x*r product.
    ``num_warps`` defaults to a power of two in [2, 8] derived from the
    block's element count.
    """

    target = conditional_product(x, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    r = min(r, size_hints[1])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, r) < target:
        x *= 2
    while r < size_hints[1] and conditional_product(x, r) < target:
        r *= 2

    cfg = {"XBLOCK": x, "RBLOCK": r}
    if num_warps is None:
        num_warps = conditional_product(x, r) // 128
    num_warps = next_power_of_2(min(max(num_warps, 2), 8))
    check_config(cfg, xnumel=size_hints[0])
    # Fixed error message: the actual knob is config.triton.max_block["R"]
    # (there is no MAX_BLOCK['r']).
    assert (
        r <= config.triton.max_block["R"]
    ), f'increase config.triton.max_block["R"] to {r}'
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
    """
    Construct a tile reduction triton config with some adjustment
    heuristics based on size_hints. Size_hints is a tuple of numels in
    each tile dimension and will be rounded up to the nearest power of 2.

    Like :func:`triton_config_reduction` but with an extra Y tile
    dimension; blocks are clamped to the hints and grown in x, r, y
    order toward the requested x*y*r product.
    """

    target = conditional_product(x, y, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    y = min(y, size_hints[1])
    r = min(r, size_hints[2])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, y, r) < target:
        x *= 2
    while r < size_hints[2] and conditional_product(x, y, r) < target:
        r *= 2
    while y < size_hints[1] and conditional_product(x, y, r) < target:
        y *= 2

    cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
    num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8))
    check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
    # Fixed error message: the actual knob is config.triton.max_block["R"]
    # (there is no MAX_BLOCK['r']).
    assert (
        r <= config.triton.max_block["R"]
    ), f'increase config.triton.max_block["R"] to {r}'
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def pointwise(
    size_hints,
    triton_meta,
    tile_hint=None,
    filename=None,
    min_elem_per_thread=0,
    inductor_meta=None,
):
    """
    Construct @triton.heuristics() based on size_hints.

    Dispatches on the tiling rank (1, 2 or 3 size hints): when autotuning
    is disabled a single heuristic config is used, otherwise a small
    candidate set (plus any hint-derived configs) is handed to
    cached_autotune. Raises NotImplementedError for other ranks.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    assert not inductor_meta.get("no_x_dim")

    numel = functools.reduce(operator.mul, size_hints)
    # Default 1D block size: between 256 and 1024 depending on total numel.
    bs = max(256, min(numel // 128, 1024))

    hinted_configs = autotune_hints_to_configs(
        inductor_meta.get("autotune_hints", set()), size_hints, bs
    )

    triton_config_with_settings = functools.partial(
        triton_config, min_elem_per_thread=min_elem_per_thread
    )

    if len(size_hints) == 1:
        # max-autotune overrides the "autotuning disabled" fast path
        if disable_pointwise_autotuning() and not (
            config.max_autotune or config.max_autotune_pointwise
        ):
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, bs)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        else:
            return cached_autotune(
                size_hints,
                [
                    triton_config_with_settings(
                        size_hints, bs, num_elements_per_warp=256
                    ),
                    triton_config_with_settings(
                        size_hints, bs // 2, num_elements_per_warp=64
                    ),
                    *hinted_configs,
                ],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
    if len(size_hints) == 2:
        # Square tiles need no tuning; a 32x32 tile is used directly.
        if (disable_pointwise_autotuning() or tile_hint == TileHint.SQUARE) and not (
            config.max_autotune or config.max_autotune_pointwise
        ):
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, 32, 32)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        return cached_autotune(
            size_hints,
            [
                triton_config_with_settings(size_hints, 32, 32),
                triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                triton_config_with_settings(size_hints, 256, 16),
                triton_config_with_settings(size_hints, 16, 256),
                triton_config_with_settings(size_hints, bs, 1),
                triton_config_with_settings(size_hints, 1, bs),
                *hinted_configs,
            ],
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            filename=filename,
            heuristic_type=HeuristicType.POINTWISE,
        )
    if len(size_hints) == 3:
        if disable_pointwise_autotuning():
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, 16, 16, 16)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        return cached_autotune(
            size_hints,
            [
                triton_config_with_settings(size_hints, 16, 16, 16),
                triton_config_with_settings(size_hints, 64, 8, 8),
                triton_config_with_settings(size_hints, 8, 64, 8),
                triton_config_with_settings(size_hints, 8, 8, 64),
                triton_config_with_settings(size_hints, bs, 1, 1),
                triton_config_with_settings(size_hints, 1, bs, 1),
                triton_config_with_settings(size_hints, 1, 1, bs),
                *hinted_configs,
            ],
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            filename=filename,
            heuristic_type=HeuristicType.POINTWISE,
        )
    raise NotImplementedError(f"size_hints: {size_hints}")
def _reduction_configs(
    *, size_hints: List[int], inductor_meta: Dict[str, Any]
) -> List[Config]:
    """Build the candidate config list for a (non-persistent) reduction.

    With a reduction_hint and no max-autotune, a single specialized config
    is returned; otherwise a broader candidate set is produced for
    autotuning. Expects exactly two size hints: (xnumel, rnumel).
    """
    reduction_hint = inductor_meta.get("reduction_hint", None)
    assert len(size_hints) == 2
    rnumel = size_hints[-1]

    # Contiguous (inner) reductions: one x row per program, large RBLOCK.
    contiguous_config = triton_config_reduction(
        size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048)
    )
    outer_config = triton_config_reduction(size_hints, 64, 8)
    tiny_config = triton_config_reduction(
        size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048)
    )
    if config.max_autotune or config.max_autotune_pointwise:
        pass  # skip all these cases
    elif reduction_hint == ReductionHint.INNER:
        return [contiguous_config]
    elif reduction_hint == ReductionHint.OUTER:
        return [outer_config]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        return [tiny_config]
    if disable_pointwise_autotuning():
        return [triton_config_reduction(size_hints, 32, 128)]
    return [
        contiguous_config,
        outer_config,
        tiny_config,
        triton_config_reduction(size_hints, 64, 64),
        triton_config_reduction(size_hints, 8, 512),
        # halve the XBLOCK/RBLOCK compared to outer_config
        # TODO: this may only be beneficial when each iteration of the reduction
        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
        triton_config_reduction(size_hints, 64, 4, num_warps=8),
    ]
def reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """args to @triton.heuristics() for a reduction kernel.

    Stores the reduction hint in ``inductor_meta`` and hands the candidate
    configs from :func:`_reduction_configs` to :func:`cached_autotune`.
    Only 2D size hints (xnumel, rnumel) are supported.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        # x dimension was folded away by the kernel; force its hint to 1
        size_hints = [1, *size_hints[1:]]

    assert triton_meta is not None
    # (removed an unused `rnumel = size_hints[-1]` local)
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.REDUCTION,
        filename=filename,
    )
def persistent_reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """args to @triton.heuristics() for a persistent reduction kernel.

    Persistent reductions process the entire reduction dimension in one
    block, so the configs are built with RBLOCK == rnumel and RBLOCK is
    then stripped from the tuned kwargs.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        # x dimension was folded away by the kernel; force its hint to 1
        size_hints = [1, *size_hints[1:]]

    xnumel, rnumel = size_hints

    # Candidate XBLOCKs, capped so xblock * rnumel stays within 4096
    # elements and xblock does not exceed xnumel (xblock == 1 always kept).
    configs = [
        triton_config_reduction(size_hints, xblock, rnumel)
        for xblock in (1, 8, 32, 128)
        if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
    ]

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        configs = [
            triton_config_reduction(
                size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
            )
        ]
    for c in configs:
        # we don't need RBLOCK for persistent reduction
        c.kwargs.pop("RBLOCK")

    if disable_pointwise_autotuning():
        configs = configs[:1]

    return cached_autotune(
        size_hints,
        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        filename=filename,
        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
    )
def split_scan(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """Heuristic for TritonSplitScanKernel.

    Reuses the reduction candidate configs but enforces the minimum
    RBLOCK required by split-scan kernels
    (``config.triton.min_split_scan_rblock``). Only 2D size hints are
    supported.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        # x dimension was folded away by the kernel; force its hint to 1
        size_hints = [1, *size_hints[1:]]

    assert triton_meta is not None
    # (removed an unused `rnumel = size_hints[-1]` local)
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)

    # Fixup configs to enforce the minimum RBLOCK size
    min_rblock = config.triton.min_split_scan_rblock
    for cfg in configs:
        if cfg.kwargs["RBLOCK"] < min_rblock:
            cfg.kwargs["RBLOCK"] = min_rblock

    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.SPLIT_SCAN,
        filename=filename,
    )
def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
    """
    Compile a triton template

    Uses a single fixed config (no kwargs to tune) with the given
    pipeline stages and warp count.
    """
    fixed_config = triton.Config({}, num_stages=num_stages, num_warps=num_warps)
    return cached_autotune(
        None,
        [fixed_config],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )
def user_autotune(
    configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
):
    """
    Compile a user defined triton kernel

    ``configs`` is a list of dicts with optional "kwargs", "num_stages"
    and "num_warps" keys; missing values fall back to triton.Config's own
    signature defaults. An empty list yields one default config.
    """
    signature_params = inspect.signature(triton.Config).parameters
    default_num_stages = signature_params["num_stages"].default
    default_num_warps = signature_params["num_warps"].default

    if not configs:
        triton_configs = [
            triton.Config(
                {}, num_stages=default_num_stages, num_warps=default_num_warps
            )
        ]
    else:
        triton_configs = [
            triton.Config(
                cfg.get("kwargs", {}),
                num_stages=cfg.get("num_stages", default_num_stages),
                num_warps=cfg.get("num_warps", default_num_warps),
            )
            for cfg in configs
        ]

    return cached_autotune(
        None,
        triton_configs,
        triton_meta=triton_meta,
        heuristic_type=HeuristicType.USER_AUTOTUNE,
        filename=filename,
        inductor_meta=inductor_meta,
        custom_kernel=custom_kernel,
    )
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel

    A single fixed config (num_stages=1) is used; only the warp count is
    caller-controlled.
    """
    single_config = triton.Config({}, num_stages=1, num_warps=num_warps)
    return cached_autotune(
        None,
        [single_config],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )
def grid(*numels):
    """Helper function to compute triton grids.

    ``numels`` are given outermost-first, so the last numel maps to the x
    grid dimension. Returns a ``grid_fn(meta)`` that divides each numel by
    the corresponding block size from the config meta. When there is no z
    numel and at most 2 tiles are allowed, an oversized y grid is split
    across the z dimension to stay under the hardware y-grid limit.
    """
    if len(numels) == 1:
        xnumel, ynumel, znumel = numels[0], None, None
    elif len(numels) == 2:
        xnumel, ynumel, znumel = numels[1], numels[0], None
    elif len(numels) == 3:
        xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
    else:
        raise AssertionError(f"invalid size for numels {len(numels)}")

    def get_grid_dim(numel, block):
        # No numel -> unused dimension; no block -> one program per element.
        if numel is None:
            return 1
        if block is None:
            return numel
        return ceildiv(numel, block)

    max_grid_dims = config.triton.max_tiles

    def grid_fn(meta):
        x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
        y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))

        MAX_Y_GRID = get_max_y_grid()
        if znumel is None and max_grid_dims <= 2:
            # Fold the overflow of the y grid into the (otherwise unused)
            # z dimension.
            div = ceildiv(y_grid, MAX_Y_GRID)
            y_grid = y_grid // div
            z_grid = div
        else:
            z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
            torch._check(
                y_grid <= MAX_Y_GRID,
                lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
            )

        return (
            x_grid,
            y_grid,
            z_grid,
        )

    return grid_fn
def split_scan_grid(xnumel, rnumel):
    """Grid for split-scan kernels: the reduction dimension is split
    across the first grid axis, one x row per program on the second."""

    def grid_fn(meta):
        # Split-scan kernels require XBLOCK == 1.
        assert meta.get("XBLOCK", 1) == 1
        rblock = meta.get("RBLOCK", 1)
        return (ceildiv(rnumel, rblock), xnumel, 1)

    return grid_fn
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.26 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc ADDED
Binary file (1.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc ADDED
Binary file (550 Bytes). View file